From e2d1e61316ced2066aeb8b95fd09809eba40f495 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 1 Oct 2025 10:15:27 +0200
Subject: [PATCH 01/16] Set CI for debugging
---
.github/workflows/tests.yml | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 4231ef227ec..4c21a98f4f9 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -21,7 +21,7 @@ jobs:
check_code_quality:
name: Check code quality
runs-on: ubuntu-latest
- if: github.event.pull_request.draft == false
+ # if: github.event.pull_request.draft == false
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.12
@@ -36,7 +36,7 @@ jobs:
name: Tests
strategy:
matrix:
- python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
+ python-version: ['3.10']
fail-fast: false
runs-on:
group: aws-g4dn-2xlarge
@@ -46,7 +46,7 @@ jobs:
defaults:
run:
shell: bash
- if: github.event.pull_request.draft == false
+ # if: github.event.pull_request.draft == false
steps:
- name: Git checkout
uses: actions/checkout@v4
From 40666811f4b9f1ba13c28fc80787945016a35857 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 1 Oct 2025 10:17:54 +0200
Subject: [PATCH 02/16] Convert syntax from unittest to pytest
---
tests/slow/test_dpo_slow.py | 8 +-
tests/slow/test_grpo_slow.py | 46 ++-
tests/slow/test_sft_slow.py | 12 +-
tests/test_activation_offloading.py | 12 +-
tests/test_bco_trainer.py | 62 ++--
tests/test_best_of_n_sampler.py | 6 +-
tests/test_callbacks.py | 28 +-
tests/test_cli.py | 4 +-
tests/test_cli_utils.py | 147 ++++-----
tests/test_collators.py | 2 +-
tests/test_core.py | 6 +-
tests/test_cpo_trainer.py | 14 +-
tests/test_data_utils.py | 204 ++++++-------
tests/test_dataset_formatting.py | 74 +++--
tests/test_dpo_trainer.py | 142 ++++-----
tests/test_gkd_trainer.py | 86 +++---
tests/test_grpo_trainer.py | 246 ++++++++-------
tests/test_judges.py | 22 +-
tests/test_kto_trainer.py | 78 +++--
...test_modeling_geometric_mixture_wrapper.py | 14 +-
tests/test_modeling_value_head.py | 104 +++----
tests/test_nash_md_trainer.py | 12 +-
tests/test_online_dpo_trainer.py | 110 +++----
tests/test_orpo_trainer.py | 10 +-
tests/test_peft_models.py | 54 ++--
tests/test_ppo_trainer.py | 8 +-
tests/test_prm_trainer.py | 52 ++--
tests/test_reward_trainer.py | 112 ++++---
tests/test_rewards.py | 12 +-
tests/test_rloo_trainer.py | 157 +++++-----
tests/test_sft_trainer.py | 288 +++++++++---------
tests/test_trainers_args.py | 186 +++++------
tests/test_utils.py | 283 +++++++++--------
tests/test_vllm_client_server.py | 70 +++--
tests/test_xpo_trainer.py | 12 +-
35 files changed, 1276 insertions(+), 1407 deletions(-)
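The conversion throughout this patch is mechanical: each unittest assertion helper becomes a bare assert (or a pytest.raises context manager for expected exceptions), keeping the original operands and failure messages. A minimal sketch of the mapping used across the files below; the test body is illustrative and not taken from the diff:

    import pytest

    # self.assertEqual(a, b) / assertListEqual(a, b)    ->  assert a == b
    # self.assertTrue(x) / self.assertFalse(x)          ->  assert x / assert not x
    # self.assertIsNone(x) / self.assertIsNotNone(x)    ->  assert x is None / assert x is not None
    # self.assertIsInstance(x, T)                       ->  assert isinstance(x, T)
    # self.assertIn(a, b) / self.assertNotIn(a, b)      ->  assert a in b / assert a not in b
    # self.assertLess(a, b) / self.assertGreater(a, b)  ->  assert a < b / assert a > b
    # self.assertRaises(E) / assertRaisesRegex(E, p)    ->  pytest.raises(E) / pytest.raises(E, match=p)

    def test_mapping_example():
        log_history = [{"train_loss": 0.1}]
        assert log_history[-1]["train_loss"] is not None  # was: self.assertIsNotNone(...)
        with pytest.raises(ValueError, match="No datasets were loaded"):  # was: self.assertRaisesRegex(...)
            raise ValueError("No datasets were loaded in the dataset mixture")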
diff --git a/tests/slow/test_dpo_slow.py b/tests/slow/test_dpo_slow.py
index 3b76fd8ea07..e24362fbc88 100644
--- a/tests/slow/test_dpo_slow.py
+++ b/tests/slow/test_dpo_slow.py
@@ -151,8 +151,8 @@ def test_dpo_peft_model(self, model_id, loss_type, pre_compute_logits, gradient_
peft_config=self.peft_config,
)
- self.assertIsInstance(trainer.model, PeftModel)
- self.assertIsNone(trainer.ref_model)
+ assert isinstance(trainer.model, PeftModel)
+ assert trainer.ref_model is None
# train the model
trainer.train()
@@ -215,8 +215,8 @@ def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gra
peft_config=self.peft_config,
)
- self.assertIsInstance(trainer.model, PeftModel)
- self.assertIsNone(trainer.ref_model)
+ assert isinstance(trainer.model, PeftModel)
+ assert trainer.ref_model is None
# train the model
trainer.train()
diff --git a/tests/slow/test_grpo_slow.py b/tests/slow/test_grpo_slow.py
index 75dd1dc8d9e..5f4400b9115 100644
--- a/tests/slow/test_grpo_slow.py
+++ b/tests/slow/test_grpo_slow.py
@@ -103,7 +103,7 @@ def test_training_with_liger_grpo_loss(self, model_name):
for n, param in previous_trainable_params.items():
new_param = model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
release_memory(model, trainer)
@@ -153,20 +153,20 @@ def test_training_with_liger_grpo_loss_and_peft(self, model_name):
# Verify PEFT adapter is properly initialized
from peft import PeftModel
- self.assertTrue(isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT")
+ assert isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT"
# Store adapter weights before training
previous_trainable_params = {
n: param.clone() for n, param in trainer.model.named_parameters() if param.requires_grad
}
- self.assertTrue(len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model")
+ assert len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model"
trainer.train()
# Verify adapter weights have changed after training
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
release_memory(model, trainer)
@@ -199,12 +199,12 @@ def test_training_with_transformers_paged(self, model_name):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
release_memory(model, trainer)
@@ -310,13 +310,13 @@ def reward_func(prompts, completions, **kwargs):
peft_config=lora_config,
)
- self.assertIsInstance(trainer.model, PeftModel)
+ assert isinstance(trainer.model, PeftModel)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that LoRA parameters have changed
# For VLM models, we're more permissive about which parameters can change
@@ -328,7 +328,7 @@ def reward_func(prompts, completions, **kwargs):
lora_params_changed = True
# At least some LoRA parameters should have changed during training
- self.assertTrue(lora_params_changed, "No LoRA parameters were updated during training.")
+ assert lora_params_changed, "No LoRA parameters were updated during training."
except torch.OutOfMemoryError as e:
self.skipTest(f"Skipping VLM training test due to insufficient GPU memory: {e}")
@@ -378,8 +378,8 @@ def test_vlm_processor_vllm_colocate_mode(self):
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct", use_fast=True, padding_side="left")
# Verify processor has both required attributes for VLM detection
- self.assertTrue(hasattr(processor, "tokenizer"))
- self.assertTrue(hasattr(processor, "image_processor"))
+ assert hasattr(processor, "tokenizer")
+ assert hasattr(processor, "image_processor")
def dummy_reward_func(completions, **kwargs):
return [1.0] * len(completions)
@@ -438,17 +438,15 @@ def dummy_reward_func(completions, **kwargs):
)
# Should detect VLM processor correctly and allow vLLM
- self.assertTrue(trainer.use_vllm, "vLLM should be enabled for VLM processors in colocate mode")
- self.assertEqual(trainer.vllm_mode, "colocate", "Should use colocate mode")
+ assert trainer.use_vllm, "vLLM should be enabled for VLM processors in colocate mode"
+ assert trainer.vllm_mode == "colocate", "Should use colocate mode"
# Check if signature columns were set properly
if trainer._signature_columns is not None:
# Should include 'image' in signature columns for VLM processors
- self.assertIn(
- "image",
- trainer._signature_columns,
- "Should include 'image' in signature columns for VLM",
- )
+ assert "image" in \
+ trainer._signature_columns, \
+ "Should include 'image' in signature columns for VLM"
# Should not emit any warnings about VLM incompatibility
incompatibility_warnings = [
@@ -457,11 +455,9 @@ def dummy_reward_func(completions, **kwargs):
if "does not support VLMs" in str(w_item.message)
or "not compatible" in str(w_item.message).lower()
]
- self.assertEqual(
- len(incompatibility_warnings),
- 0,
- f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}",
- )
+ assert len(incompatibility_warnings) == 0, (
+ f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}"
+ )
# Test passes if we get this far without exceptions
@@ -525,12 +521,12 @@ def test_training_vllm(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
except Exception as e:
# If vLLM fails to initialize due to hardware constraints or other issues, that's expected
diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py
index db762df107d..7e673a9457c 100755
--- a/tests/slow/test_sft_slow.py
+++ b/tests/slow/test_sft_slow.py
@@ -148,7 +148,7 @@ def test_sft_trainer_peft(self, model_name, packing):
peft_config=self.peft_config,
)
- self.assertIsInstance(trainer.model, PeftModel)
+ assert isinstance(trainer.model, PeftModel)
trainer.train()
@@ -252,7 +252,7 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient
peft_config=self.peft_config,
)
- self.assertIsInstance(trainer.model, PeftModel)
+ assert isinstance(trainer.model, PeftModel)
trainer.train()
@@ -332,7 +332,7 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr
peft_config=self.peft_config,
)
- self.assertIsInstance(trainer.model, PeftModel)
+ assert isinstance(trainer.model, PeftModel)
trainer.train()
@@ -372,7 +372,7 @@ def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
peft_config=self.peft_config,
)
- self.assertIsInstance(trainer.model, PeftModel)
+ assert isinstance(trainer.model, PeftModel)
trainer.train()
@@ -447,11 +447,11 @@ def test_train_offloading(self, model_name, packing):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
release_memory(trainer.model, trainer)
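The parameter-change checks above use two different comparisons: torch.equal, which requires identical shape, dtype and exactly equal values, and torch.allclose, which tolerates small numerical differences. A small sketch with made-up tensors:

    import torch

    a = torch.tensor([1.0, 2.0, 3.0])
    b = a + 1e-6  # tiny perturbation, e.g. a parameter after one very small update

    print(torch.equal(a, b))                           # False: exact comparison, any difference counts
    print(torch.allclose(a, b, rtol=1e-4, atol=1e-5))  # True: difference is within tolerance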
diff --git a/tests/test_activation_offloading.py b/tests/test_activation_offloading.py
index a80563005f5..b1e8f59d61d 100644
--- a/tests/test_activation_offloading.py
+++ b/tests/test_activation_offloading.py
@@ -72,10 +72,8 @@ def test_offloading_with_peft_models(self) -> None:
for name_orig, grad_orig in grads_original:
for name_param, param in model.named_parameters():
if name_param == name_orig and param.requires_grad and param.grad is not None:
- self.assertTrue(
- torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5),
- f"Gradient mismatch for {name_orig}",
- )
+ assert torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5), \
+ f"Gradient mismatch for {name_orig}"
@require_torch_accelerator
def test_noop_manager_with_offloading(self):
@@ -105,7 +103,7 @@ def test_noop_manager_with_offloading(self):
# Gradients should match as NoOpManager should have prevented offloading
for g1, g2 in zip(grads1, grads2):
- self.assertTrue(torch.allclose(g1, g2, rtol=1e-4, atol=1e-5))
+ assert torch.allclose(g1, g2, rtol=1e-4, atol=1e-5)
@require_torch_accelerator
def test_min_offload_size(self):
@@ -152,6 +150,6 @@ def test_real_hf_model(self):
grads2 = [p.grad.clone() for p in model.parameters()]
# Check outputs and gradients match
- self.assertTrue(torch.allclose(out1, out2, rtol=1e-5))
+ assert torch.allclose(out1, out2, rtol=1e-5)
for g1, g2 in zip(grads1, grads2):
- self.assertTrue(torch.allclose(g1, g2, rtol=1e-5))
+ assert torch.allclose(g1, g2, rtol=1e-5)
diff --git a/tests/test_bco_trainer.py b/tests/test_bco_trainer.py
index d609c3d5b90..b1fdaf8d8cf 100644
--- a/tests/test_bco_trainer.py
+++ b/tests/test_bco_trainer.py
@@ -26,6 +26,7 @@
from trl.trainer.bco_trainer import _process_tokens, _tokenize
from .testing_utils import TrlTestCase, require_no_wandb, require_sklearn
+import pytest
if is_peft_available():
@@ -71,13 +72,13 @@ def test_train(self, config_name):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.equal(param.cpu(), new_param.cpu()))
+ assert not torch.equal(param.cpu(), new_param.cpu())
@require_sklearn
def test_train_with_precompute(self):
@@ -108,13 +109,13 @@ def test_train_with_precompute(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.equal(param.cpu(), new_param.cpu()))
+ assert not torch.equal(param.cpu(), new_param.cpu())
@require_sklearn
def test_train_eval(self):
@@ -158,7 +159,7 @@ def test_init_with_ref_model_is_model(self):
report_to="none",
)
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
BCOTrainer(
model=model,
ref_model=model, # ref_model can't be the same as model
@@ -196,13 +197,13 @@ def test_tokenize_and_process_tokens(self):
batched=True,
batch_size=2,
)
- self.assertListEqual(tokenized_dataset["prompt"][:], dataset["prompt"][:])
- self.assertListEqual(tokenized_dataset["completion"][:], dataset["completion"][:])
- self.assertListEqual(tokenized_dataset["label"][:], dataset["label"][:])
- self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
- self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
- self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13])
- self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1])
+ assert tokenized_dataset["prompt"][:] == dataset["prompt"][:]
+ assert tokenized_dataset["completion"][:] == dataset["completion"][:]
+ assert tokenized_dataset["label"][:] == dataset["label"][:]
+ assert tokenized_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091]
+ assert tokenized_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1]
+ assert tokenized_dataset["answer_input_ids"][0] == [27261, 13]
+ assert tokenized_dataset["answer_attention_mask"][0] == [1, 1]
fn_kwargs = {
"prefix": "",
@@ -214,14 +215,14 @@ def test_tokenize_and_process_tokens(self):
"max_prompt_length": trainer.max_prompt_length,
}
processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs)
- self.assertListEqual(processed_dataset["prompt"][:], dataset["prompt"][:])
- self.assertListEqual(processed_dataset["completion"][:], dataset["completion"][:])
- self.assertListEqual(processed_dataset["label"][:], dataset["label"][:])
- self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
- self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
- self.assertListEqual(processed_dataset["completion_input_ids"][0], [46518, 374, 2664, 1091, 27261, 13, 151645])
- self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1])
- self.assertListEqual(processed_dataset["completion_labels"][0], [-100, -100, -100, -100, 27261, 13, 151645])
+ assert processed_dataset["prompt"][:] == dataset["prompt"][:]
+ assert processed_dataset["completion"][:] == dataset["completion"][:]
+ assert processed_dataset["label"][:] == dataset["label"][:]
+ assert processed_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091]
+ assert processed_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1]
+ assert processed_dataset["completion_input_ids"][0] == [46518, 374, 2664, 1091, 27261, 13, 151645]
+ assert processed_dataset["completion_attention_mask"][0] == [1, 1, 1, 1, 1, 1, 1]
+ assert processed_dataset["completion_labels"][0] == [-100, -100, -100, -100, 27261, 13, 151645]
@require_sklearn
def test_train_without_providing_ref_model(self):
@@ -249,13 +250,13 @@ def test_train_without_providing_ref_model(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.equal(param.cpu(), new_param.cpu()))
+ assert not torch.equal(param.cpu(), new_param.cpu())
@require_sklearn
def test_train_udm(self):
@@ -298,13 +299,13 @@ def embed_prompt(input_ids, attention_mask, model):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.equal(param.cpu(), new_param.cpu()))
+ assert not torch.equal(param.cpu(), new_param.cpu())
@require_sklearn
@require_peft
@@ -335,14 +336,14 @@ def test_train_without_providing_ref_model_with_lora(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
if "lora" in n:
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.equal(param.cpu(), new_param.cpu()))
+ assert not torch.equal(param.cpu(), new_param.cpu())
@require_sklearn
@require_no_wandb
@@ -362,11 +363,8 @@ def test_generate_during_eval_no_wandb(self):
report_to="none",
)
- with self.assertRaisesRegex(
- ValueError,
- expected_regex="`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
- " Please install `wandb` or `comet-ml` to resolve.",
- ):
+ with pytest.raises(ValueError, match="`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
+ " Please install `wandb` or `comet-ml` to resolve."):
BCOTrainer(
model=model,
args=training_args,
@@ -440,4 +438,4 @@ def dummy_compute_metrics(*args, **kwargs):
trainer.train()
- self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)
+ assert trainer.state.log_history[-2]["eval_test"] == 0.0
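One subtlety for conversions like the assertRaisesRegex block above: pytest.raises applies match with re.search, exactly as assertRaisesRegex does, so the argument is a regular expression rather than a literal substring. When the expected message contains regex metacharacters, re.escape keeps the check literal. A short sketch with a made-up error message:

    import re

    import pytest

    def fails():
        raise ValueError("expected a value in [0, 1] (got 2.5)")

    def test_literal_match():
        # re.escape turns the brackets and parentheses into literal characters
        with pytest.raises(ValueError, match=re.escape("expected a value in [0, 1] (got 2.5)")):
            fails()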
diff --git a/tests/test_best_of_n_sampler.py b/tests/test_best_of_n_sampler.py
index 471f75c0c7c..cf6810976ec 100644
--- a/tests/test_best_of_n_sampler.py
+++ b/tests/test_best_of_n_sampler.py
@@ -74,8 +74,8 @@ def test_different_input_types(self):
for q, expected_length in various_queries_formats:
results = best_of_n.generate(q)
- self.assertIsInstance(results, list)
- self.assertEqual(len(results), expected_length)
+ assert isinstance(results, list)
+ assert len(results) == expected_length
def test_different_sample_sizes_and_n_candidates_values(self):
r"""
@@ -110,4 +110,4 @@ def test_different_sample_sizes_and_n_candidates_values(self):
tokenized_queries = [self.tokenizer.encode(query) for query in queries]
results = best_of_n.generate(tokenized_queries)
for result in results:
- self.assertEqual(len(result), expected)
+ assert len(result) == expected
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index 7904a4ae374..501641b5df9 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -119,7 +119,7 @@ def test_basic(self):
trainer.train()
winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
for history_row, expected_row in zip(winrate_history, self.expected_winrates):
- self.assertTrue(all(key in history_row and history_row[key] == expected_row[key] for key in expected_row))
+ assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)
def test_without_ref_model(self):
# Same as before, but without the ref_model attribute. It should use the model attribute instead
@@ -145,7 +145,7 @@ def test_without_ref_model(self):
trainer.train()
winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
for history_row, expected_row in zip(winrate_history, self.expected_winrates):
- self.assertTrue(all(key in history_row and history_row[key] == expected_row[key] for key in expected_row))
+ assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)
def test_soft_judge(self):
"""Test that the soft judge functionality works correctly"""
@@ -188,7 +188,7 @@ def test_soft_judge(self):
if "eval_avg_win_prob" in h
]
for history_row, expected_row in zip(winrate_history, expected_soft_winrates):
- self.assertTrue(all(key in history_row and history_row[key] == expected_row[key] for key in expected_row))
+ assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)
@require_peft
def test_lora(self):
@@ -222,7 +222,7 @@ def test_lora(self):
trainer.train()
winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
for history_row, expected_row in zip(winrate_history, self.expected_winrates):
- self.assertTrue(all(key in history_row and history_row[key] == expected_row[key] for key in expected_row))
+ assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)
class LogCompletionsCallbackTester(TrlTestCase):
@@ -273,12 +273,12 @@ def test_basic_wandb(self):
completions = json.load(f)
# Check that the columns are correct
- self.assertIn("step", completions["columns"])
- self.assertIn("prompt", completions["columns"])
- self.assertIn("completion", completions["columns"])
+ assert "step" in completions["columns"]
+ assert "prompt" in completions["columns"]
+ assert "completion" in completions["columns"]
# Check that the prompt is in the log
- self.assertIn(self.dataset["test"][0]["prompt"], completions["data"][0])
+ assert self.dataset["test"][0]["prompt"] in completions["data"][0]
@require_comet
def test_basic_comet(self):
@@ -347,7 +347,7 @@ def test_callback(self):
trainer.train()
last_checkpoint = get_last_checkpoint(self.tmp_dir)
merged_path = os.path.join(last_checkpoint, "merged")
- self.assertTrue(os.path.isdir(merged_path), "Merged folder does not exist in the last checkpoint.")
+ assert os.path.isdir(merged_path), "Merged folder does not exist in the last checkpoint."
def test_every_checkpoint(self):
training_args = DPOConfig(
@@ -374,7 +374,7 @@ def test_every_checkpoint(self):
for checkpoint in checkpoints:
merged_path = os.path.join(checkpoint, "merged")
- self.assertTrue(os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}.")
+ assert os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}."
class BEMACallbackTester(TrlTestCase):
@@ -409,7 +409,7 @@ def test_model_saved(self):
# Check that the BEMA model was saved and can be loaded
bema_path = os.path.join(self.tmp_dir, "bema")
- self.assertTrue(os.path.isdir(bema_path), "BEMA directory was not created")
+ assert os.path.isdir(bema_path), "BEMA directory was not created"
AutoModelForCausalLM.from_pretrained(bema_path)
def test_update_frequency_0(self):
@@ -430,7 +430,7 @@ def test_update_frequency_0(self):
# Total 9 steps (17 samples, batch size 8, 3 epochs).
# BEMA starts after step 0 and updates every 2 steps → updates at 2, 4, 6, 8
- self.assertEqual(mock_update.call_args_list, [call(2), call(4), call(6), call(8)])
+ assert mock_update.call_args_list == [call(2), call(4), call(6), call(8)]
def test_update_frequency_1(self):
"""Test that BEMA callback respects the update frequency."""
@@ -450,7 +450,7 @@ def test_update_frequency_1(self):
# Total 9 steps (17 samples, batch size 8, 3 epochs).
# BEMA starts after step 0 and updates every 3 steps → updates at 3, 6, 9
- self.assertEqual(mock_update.call_args_list, [call(3), call(6), call(9)])
+ assert mock_update.call_args_list == [call(3), call(6), call(9)]
def test_update_frequency_2(self):
"""Test that BEMA callback respects the update frequency."""
@@ -470,7 +470,7 @@ def test_update_frequency_2(self):
# Total 9 steps (17 samples, batch size 8, 3 epochs).
# BEMA starts after step 3 and updates every 2 steps → updates at 5, 7, 9
- self.assertEqual(mock_update.call_args_list, [call(5), call(7), call(9)])
+ assert mock_update.call_args_list == [call(5), call(7), call(9)]
def test_no_bema(self):
"""Test that BEMACallback works without BEMA updates."""
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 23b5d6bcff7..638e1d38493 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -51,7 +51,7 @@ def test_env(self, mock_stdout):
command = "trl env"
with patch("sys.argv", command.split(" ")):
main()
- self.assertIn("TRL version: ", mock_stdout.getvalue().strip())
+ assert "TRL version: " in mock_stdout.getvalue().strip()
def test_grpo(self):
from trl.cli import main
@@ -112,7 +112,7 @@ def test_sft_config_file(self):
main()
# Verify that output directory was created
- self.assertTrue(os.path.exists(output_dir))
+ assert os.path.exists(output_dir)
if __name__ == "__main__":
diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py
index 271dd6f5e5b..01417c258bc 100644
--- a/tests/test_cli_utils.py
+++ b/tests/test_cli_utils.py
@@ -23,6 +23,7 @@
from trl.scripts.utils import DatasetConfig
from .testing_utils import TrlTestCase
+import pytest
@dataclass
@@ -40,13 +41,13 @@ class TestTrlParser(TrlTestCase):
def test_init_without_config_field(self):
"""Test initialization without 'config' field in the dataclasses."""
parser = TrlParser(dataclass_types=[MyDataclass])
- self.assertIsInstance(parser, TrlParser)
+ assert isinstance(parser, TrlParser)
def test_init_with_config_field(self):
"""Test initialization with a 'config' field in the dataclass (should raise ValueError)."""
- with self.assertRaises(ValueError) as context:
+ with pytest.raises(ValueError) as context:
TrlParser(dataclass_types=[InvalidDataclass])
- self.assertTrue("has a field named 'config'" in str(context.exception))
+ assert "has a field named 'config'" in str(context.exception)
@patch("builtins.open", mock_open(read_data="env:\n VAR1: value1\n VAR2: value2\narg1: 2"))
@patch("yaml.safe_load")
@@ -67,14 +68,14 @@ def test_parse_args_and_config_with_valid_config(self, mock_environ, mock_yaml_l
mock_environ["VAR2"] = "value2"
# Ensure that the environment variables were set correctly
- self.assertEqual(mock_environ.get("VAR1"), "value1")
- self.assertEqual(mock_environ.get("VAR2"), "value2")
+ assert mock_environ.get("VAR1") == "value1"
+ assert mock_environ.get("VAR2") == "value2"
# Check the parsed arguments
- self.assertEqual(len(result_args), 1)
- self.assertIsInstance(result_args[0], MyDataclass)
- self.assertEqual(result_args[0].arg1, 2)
- self.assertEqual(result_args[0].arg2, "value")
+ assert len(result_args) == 1
+ assert isinstance(result_args[0], MyDataclass)
+ assert result_args[0].arg1 == 2
+ assert result_args[0].arg2 == "value"
@patch("builtins.open", mock_open(read_data="arg1: 2"))
@patch("yaml.safe_load")
@@ -90,9 +91,9 @@ def test_parse_args_and_arg_override_config(self, mock_yaml_load):
result_args = parser.parse_args_and_config(args)
# Check the parsed arguments
- self.assertEqual(len(result_args), 1)
- self.assertIsInstance(result_args[0], MyDataclass)
- self.assertEqual(result_args[0].arg1, 3)
+ assert len(result_args) == 1
+ assert isinstance(result_args[0], MyDataclass)
+ assert result_args[0].arg1 == 3
@patch("builtins.open", mock_open(read_data="env: not_a_dict"))
@patch("yaml.safe_load")
@@ -104,10 +105,10 @@ def test_parse_args_and_config_with_invalid_env(self, mock_yaml_load):
args = ["--arg1", "2", "--arg2", "value", "--config", "config.yaml"]
- with self.assertRaises(ValueError) as context:
+ with pytest.raises(ValueError) as context:
parser.parse_args_and_config(args)
- self.assertEqual(str(context.exception), "`env` field should be a dict in the YAML file.")
+ assert str(context.value) == "`env` field should be a dict in the YAML file."
def test_parse_args_and_config_without_config(self):
"""Test parse_args_and_config without the `--config` argument."""
@@ -119,10 +120,10 @@ def test_parse_args_and_config_without_config(self):
result_args = parser.parse_args_and_config(args)
# Check that the arguments are parsed as is
- self.assertEqual(len(result_args), 1)
- self.assertIsInstance(result_args[0], MyDataclass)
- self.assertEqual(result_args[0].arg1, 2)
- self.assertEqual(result_args[0].arg2, "value")
+ assert len(result_args) == 1
+ assert isinstance(result_args[0], MyDataclass)
+ assert result_args[0].arg1 == 2
+ assert result_args[0].arg2 == "value"
def test_set_defaults_with_config(self):
"""Test set_defaults_with_config updates the defaults."""
@@ -133,9 +134,9 @@ def test_set_defaults_with_config(self):
# Ensure the default value is updated
result_args = parser.parse_args_and_config([])
- self.assertEqual(len(result_args), 1)
- self.assertIsInstance(result_args[0], MyDataclass)
- self.assertEqual(result_args[0].arg1, 42)
+ assert len(result_args) == 1
+ assert isinstance(result_args[0], MyDataclass)
+ assert result_args[0].arg1 == 42
def test_parse_args_and_config_with_remaining_strings(self):
parser = TrlParser(dataclass_types=[MyDataclass])
@@ -146,11 +147,11 @@ def test_parse_args_and_config_with_remaining_strings(self):
result_args = parser.parse_args_and_config(args, return_remaining_strings=True)
# Check that the arguments are parsed as is
- self.assertEqual(len(result_args), 2)
- self.assertIsInstance(result_args[0], MyDataclass)
- self.assertEqual(result_args[0].arg1, 2)
- self.assertEqual(result_args[0].arg2, "value")
- self.assertEqual(result_args[1], ["remaining"])
+ assert len(result_args) == 2
+ assert isinstance(result_args[0], MyDataclass)
+ assert result_args[0].arg1 == 2
+ assert result_args[0].arg2 == "value"
+ assert result_args[1] == ["remaining"]
@patch("builtins.open", mock_open(read_data="remaining_string_in_config: abc"))
@patch("yaml.safe_load")
@@ -165,10 +166,10 @@ def test_parse_args_and_config_with_remaining_strings_in_config_and_args(self, m
result_args = parser.parse_args_and_config(args, return_remaining_strings=True)
# Check that the arguments are parsed as is
- self.assertEqual(len(result_args), 2)
- self.assertIsInstance(result_args[0], MyDataclass)
- self.assertEqual(result_args[0].arg1, 2)
- self.assertEqual(result_args[1], ["--remaining_string_in_config", "abc", "--remaining_string_in_args", "def"])
+ assert len(result_args) == 2
+ assert isinstance(result_args[0], MyDataclass)
+ assert result_args[0].arg1 == 2
+ assert result_args[1] == ["--remaining_string_in_config", "abc", "--remaining_string_in_args", "def"]
@patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value"))
@patch("yaml.safe_load")
@@ -190,11 +191,11 @@ def test_subparsers_with_config_defaults(self, mock_yaml_load):
result_args = parser.parse_args_and_config(args)
# Check main parser arguments
- self.assertEqual(len(result_args), 1)
+ assert len(result_args) == 1
# Check that config values were applied to the subparser
- self.assertEqual(result_args[0].arg1, 2) # Default from config
- self.assertEqual(result_args[0].arg2, "config_value") # Default from config
+ assert result_args[0].arg1 == 2 # Default from config
+ assert result_args[0].arg2 == "config_value" # Default from config
@patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value"))
@patch("yaml.safe_load")
@@ -216,8 +217,8 @@ def test_subparsers_with_config_defaults_and_arg_override(self, mock_yaml_load):
result_args = parser.parse_args_and_config(args)
# Command line arguments should override config
- self.assertEqual(result_args[0].arg1, 3)
- self.assertEqual(result_args[0].arg2, "config_value") # Still from config
+ assert result_args[0].arg1 == 3
+ assert result_args[0].arg2 == "config_value" # Still from config
@patch("builtins.open", mock_open(read_data="arg1: 2\nthis_arg_does_not_exist: config_value"))
@patch("yaml.safe_load")
@@ -236,7 +237,7 @@ def test_subparsers_with_config_defaults_and_arg_override_wrong_name(self, mock_
# Test with command line arguments overriding config
args = ["subcommand", "--arg1", "3", "--config", "config.yaml"]
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
parser.parse_args_and_config(args)
parser.parse_args_and_config(args, fail_with_unknown_args=False)
@@ -263,11 +264,11 @@ def test_subparsers_multiple_with_config_defaults(self, mock_yaml_load):
result_args = parser.parse_args_and_config(args)
# Check main parser arguments
- self.assertEqual(len(result_args), 1)
+ assert len(result_args) == 1
# Check that config values were applied to the subparser
- self.assertEqual(result_args[0].arg1, 2) # Default from config
- self.assertEqual(result_args[0].arg2, "config_value") # Default from config
+ assert result_args[0].arg1 == 2 # Default from config
+ assert result_args[0].arg2 == "config_value" # Default from config
class TestGetDataset(unittest.TestCase):
@@ -277,7 +278,7 @@ def test_single_dataset_with_config(self):
)
result = get_dataset(mixture_config)
expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling")
- self.assertEqual(expected["train"][:], result["train"][:])
+ assert expected["train"][:] == result["train"][:]
def test_single_dataset_preference_config(self):
mixture_config = DatasetMixtureConfig(
@@ -285,7 +286,7 @@ def test_single_dataset_preference_config(self):
)
result = get_dataset(mixture_config)
expected = load_dataset("trl-internal-testing/zen", "standard_preference")
- self.assertEqual(expected["train"][:], result["train"][:])
+ assert expected["train"][:] == result["train"][:]
def test_single_dataset_streaming(self):
mixture_config = DatasetMixtureConfig(
@@ -294,7 +295,7 @@ def test_single_dataset_streaming(self):
)
result = get_dataset(mixture_config)
expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling")
- self.assertEqual(expected["train"].to_list(), list(result["train"]))
+ assert expected["train"].to_list() == list(result["train"])
def test_dataset_mixture_basic(self):
dataset_config1 = DatasetConfig(
@@ -305,15 +306,15 @@ def test_dataset_mixture_basic(self):
)
mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2])
result = get_dataset(mixture_config)
- self.assertIsInstance(result, DatasetDict)
- self.assertIn("train", result)
+ assert isinstance(result, DatasetDict)
+ assert "train" in result
train_dataset = result["train"]
- self.assertEqual(train_dataset.column_names, ["prompt"])
+ assert train_dataset.column_names == ["prompt"]
prompts = train_dataset["prompt"]
expected_first_half = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
- self.assertEqual(prompts[: len(prompts) // 2], expected_first_half["prompt"])
+ assert prompts[: len(prompts) // 2] == expected_first_half["prompt"]
expected_second_half = load_dataset("trl-internal-testing/zen", "standard_prompt_completion", split="train")
- self.assertEqual(prompts[len(prompts) // 2 :], expected_second_half["prompt"])
+ assert prompts[len(prompts) // 2 :] == expected_second_half["prompt"]
def test_dataset_mixture_with_weights(self):
dataset_config1 = DatasetConfig(
@@ -324,17 +325,17 @@ def test_dataset_mixture_with_weights(self):
)
mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2])
result = get_dataset(mixture_config)
- self.assertIsInstance(result, DatasetDict)
- self.assertIn("train", result)
+ assert isinstance(result, DatasetDict)
+ assert "train" in result
train_dataset = result["train"]
- self.assertEqual(train_dataset.column_names, ["prompt"])
+ assert train_dataset.column_names == ["prompt"]
prompts = train_dataset["prompt"]
expected_first_half = load_dataset("trl-internal-testing/zen", "standard_preference", split="train[:50%]")
- self.assertEqual(prompts[: len(prompts) // 2], expected_first_half["prompt"])
+ assert prompts[: len(prompts) // 2] == expected_first_half["prompt"]
expected_second_half = load_dataset(
"trl-internal-testing/zen", "standard_prompt_completion", split="train[:50%]"
)
- self.assertEqual(prompts[len(prompts) // 2 :], expected_second_half["prompt"])
+ assert prompts[len(prompts) // 2 :] == expected_second_half["prompt"]
def test_dataset_mixture_with_test_split(self):
mixture_config = DatasetMixtureConfig(
@@ -342,19 +343,19 @@ def test_dataset_mixture_with_test_split(self):
test_split_size=2,
)
result = get_dataset(mixture_config)
- self.assertIsInstance(result, DatasetDict)
- self.assertIn("train", result)
- self.assertIn("test", result)
- self.assertEqual(len(result["train"]), 15)
- self.assertEqual(len(result["test"]), 2)
+ assert isinstance(result, DatasetDict)
+ assert "train" in result
+ assert "test" in result
+ assert len(result["train"]) == 15
+ assert len(result["test"]) == 2
def test_empty_dataset_mixture_raises_error(self):
mixture_config = DatasetMixtureConfig(datasets=[])
- with self.assertRaises(ValueError) as context:
+ with pytest.raises(ValueError) as context:
get_dataset(mixture_config)
- self.assertIn("No datasets were loaded", str(context.exception))
+ assert "No datasets were loaded" in str(context.exception)
def test_mixture_multiple_different_configs(self):
dataset_config1 = DatasetConfig(
@@ -365,9 +366,9 @@ def test_mixture_multiple_different_configs(self):
)
mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2])
result = get_dataset(mixture_config)
- self.assertIsInstance(result, DatasetDict)
- self.assertIn("train", result)
- self.assertGreater(len(result["train"]), 0)
+ assert isinstance(result, DatasetDict)
+ assert "train" in result
+ assert len(result["train"]) > 0
def test_trlparser_parses_yaml_config_correctly(self):
# Prepare YAML content exactly like your example
@@ -390,24 +391,24 @@ def test_trlparser_parses_yaml_config_correctly(self):
args = parser.parse_args_and_config(args=["--config", tmpfile.name])[0]
# Assert that we got DatasetMixtureConfig instance
- self.assertIsInstance(args, DatasetMixtureConfig)
+ assert isinstance(args, DatasetMixtureConfig)
# Assert datasets list length
- self.assertEqual(len(args.datasets), 2)
+ assert len(args.datasets) == 2
# Check first dataset
dataset_config1 = args.datasets[0]
- self.assertIsInstance(dataset_config1, DatasetConfig)
- self.assertEqual(dataset_config1.path, "trl-internal-testing/zen")
- self.assertEqual(dataset_config1.name, "standard_prompt_only")
- self.assertIsNone(dataset_config1.columns) # No columns specified
+ assert isinstance(dataset_config1, DatasetConfig)
+ assert dataset_config1.path == "trl-internal-testing/zen"
+ assert dataset_config1.name == "standard_prompt_only"
+ assert dataset_config1.columns is None # No columns specified
# Check second dataset
dataset_config2 = args.datasets[1]
- self.assertIsInstance(dataset_config2, DatasetConfig)
- self.assertEqual(dataset_config2.path, "trl-internal-testing/zen")
- self.assertEqual(dataset_config2.name, "standard_preference")
- self.assertEqual(dataset_config2.columns, ["prompt"]) # Columns specified
+ assert isinstance(dataset_config2, DatasetConfig)
+ assert dataset_config2.path == "trl-internal-testing/zen"
+ assert dataset_config2.name == "standard_preference"
+ assert dataset_config2.columns == ["prompt"] # Columns specified
def test_trlparser_parses_yaml_and_loads_dataset(self):
# Prepare YAML content exactly like your example
@@ -428,4 +429,4 @@ def test_trlparser_parses_yaml_and_loads_dataset(self):
# Load the dataset using get_dataset
result = get_dataset(args)
expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling")
- self.assertEqual(expected["train"][:], result["train"][:])
+ assert expected["train"][:] == result["train"][:]
diff --git a/tests/test_collators.py b/tests/test_collators.py
index b578758f027..d798f29d54d 100644
--- a/tests/test_collators.py
+++ b/tests/test_collators.py
@@ -26,7 +26,7 @@ def setUp(self):
self.collator = DataCollatorForPreference(pad_token_id=0)
def assertTensorEqual(self, tensor1, tensor2):
- self.assertTrue(torch.equal(tensor1, tensor2), f"Tensors are not equal:\n{tensor1}\n{tensor2}")
+ assert torch.equal(tensor1, tensor2), f"Tensors are not equal:\n{tensor1}\n{tensor2}"
def test_padding_behavior(self):
examples = [
diff --git a/tests/test_core.py b/tests/test_core.py
index bab69ca9da2..78e57c64b30 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -32,10 +32,10 @@ def setUp(self):
self.test_input_unmasked = self.test_input[1:3]
def test_masked_mean(self):
- self.assertEqual(torch.mean(self.test_input_unmasked), masked_mean(self.test_input, self.test_mask))
+ assert torch.mean(self.test_input_unmasked) == masked_mean(self.test_input, self.test_mask)
def test_masked_var(self):
- self.assertEqual(torch.var(self.test_input_unmasked), masked_var(self.test_input, self.test_mask))
+ assert torch.var(self.test_input_unmasked) == masked_var(self.test_input, self.test_mask)
def test_masked_whiten(self):
def whiten(values: torch.Tensor) -> torch.Tensor:
@@ -45,4 +45,4 @@ def whiten(values: torch.Tensor) -> torch.Tensor:
whiten_unmasked = whiten(self.test_input_unmasked)
whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3]
diffs = (whiten_unmasked - whiten_masked).sum()
- self.assertLess(abs(diffs.item()), 0.00001)
+ assert abs(diffs.item()) < 0.00001
diff --git a/tests/test_cpo_trainer.py b/tests/test_cpo_trainer.py
index cc3e394846d..c0edd771c5e 100644
--- a/tests/test_cpo_trainer.py
+++ b/tests/test_cpo_trainer.py
@@ -87,13 +87,13 @@ def test_cpo_trainer(self, name, loss_type, config_name):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.equal(param, new_param))
+ assert not torch.equal(param, new_param)
@parameterized.expand(
[
@@ -143,14 +143,14 @@ def test_cpo_trainer_with_lora(self, config_name):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
if "lora" in n:
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.equal(param, new_param))
+ assert not torch.equal(param, new_param)
def test_compute_metrics(self):
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
@@ -180,7 +180,7 @@ def dummy_compute_metrics(*args, **kwargs):
trainer.train()
- self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)
+ assert trainer.state.log_history[-2]["eval_test"] == 0.0
def test_alphapo_trainer(self):
training_args = CPOConfig(
@@ -212,9 +212,9 @@ def test_alphapo_trainer(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0:
- self.assertFalse(torch.equal(param, new_param))
+ assert not torch.equal(param, new_param)
diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py
index 0a9eba7f7bb..88c63868094 100644
--- a/tests/test_data_utils.py
+++ b/tests/test_data_utils.py
@@ -55,7 +55,7 @@ def test_basic_user_assistant_conversation(self):
{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
]
- self.assertEqual(messages, expected)
+ assert messages == expected
def test_first_user_message_gets_image(self):
"""Test that only the first user message gets an image placeholder."""
@@ -73,7 +73,7 @@ def test_first_user_message_gets_image(self):
{"role": "user", "content": [{"type": "text", "text": "How about the grass?"}]},
]
- self.assertEqual(messages, expected)
+ assert messages == expected
def test_multiple_images(self):
"""Test that multiple images are added to the first user message."""
@@ -97,7 +97,7 @@ def test_multiple_images(self):
{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
]
- self.assertEqual(messages, expected)
+ assert messages == expected
def test_system_message_transformation(self):
"""Test that system messages are properly transformed."""
@@ -113,7 +113,7 @@ def test_system_message_transformation(self):
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
]
- self.assertEqual(messages, expected)
+ assert messages == expected
def test_already_prepared_messages_unchanged(self):
"""Test that messages with list content are not modified."""
@@ -126,7 +126,7 @@ def test_already_prepared_messages_unchanged(self):
original = copy.deepcopy(messages)
prepare_multimodal_messages(messages, num_images=1)
- self.assertEqual(messages, original)
+ assert messages == original
def test_mixed_prepared_and_unprepared_messages(self):
"""Test handling of mixed prepared and unprepared messages."""
@@ -144,7 +144,7 @@ def test_mixed_prepared_and_unprepared_messages(self):
{"role": "user", "content": [{"type": "text", "text": "What about the grass?"}]},
]
- self.assertEqual(messages, expected)
+ assert messages == expected
class IsConversationalTester(TrlTestCase):
@@ -250,11 +250,11 @@ class IsConversationalTester(TrlTestCase):
@parameterized.expand(itertools.product(conversational_examples))
def test_conversational(self, example):
- self.assertTrue(is_conversational(example))
+ assert is_conversational(example)
@parameterized.expand(itertools.product(non_conversational_examples))
def test_non_conversational(self, example):
- self.assertFalse(is_conversational(example))
+ assert not is_conversational(example)
class IsConversationalFromValueTester(TrlTestCase):
@@ -265,7 +265,7 @@ def test_positive_1(self):
{"from": "assistant", "value": "It is blue."},
],
}
- self.assertTrue(is_conversational_from_value(example))
+ assert is_conversational_from_value(example)
def test_negative_1(self):
example = {
@@ -274,11 +274,11 @@ def test_negative_1(self):
{"role": "assistant", "content": "It is blue."},
],
}
- self.assertFalse(is_conversational_from_value(example))
+ assert not is_conversational_from_value(example)
def test_negative_2(self):
example = {"text": "The sky is blue."}
- self.assertFalse(is_conversational_from_value(example))
+ assert not is_conversational_from_value(example)
class ApplyChatTemplateTester(TrlTestCase):
@@ -352,24 +352,24 @@ def test_apply_chat_template(self, tokenizer_id, example):
result = apply_chat_template(example, tokenizer)
# Checking if the result is a dictionary
- self.assertIsInstance(result, dict)
+ assert isinstance(result, dict)
# The chat template should be applied to the following keys
for key in ["prompt", "chosen", "rejected", "completion"]:
if key in example:
- self.assertIn(key, result)
- self.assertIsInstance(result[key], str)
+ assert key in result
+ assert isinstance(result[key], str)
# Exception for messages, the key is "text" once the chat template is applied
if "messages" in example:
- self.assertIn("text", result)
- self.assertIsInstance(result["text"], str)
+ assert "text" in result
+ assert isinstance(result["text"], str)
# The label should be kept
if "label" in example:
- self.assertIn("label", result)
- self.assertIsInstance(result["label"], bool)
- self.assertEqual(result["label"], example["label"])
+ assert "label" in result
+ assert isinstance(result["label"], bool)
+ assert result["label"] == example["label"]
# both conversational and non-conversational examples
@parameterized.expand(itertools.product(tokenizers, conversational_examples + non_conversational_examples))
@@ -378,24 +378,24 @@ def test_maybe_apply_chat_template(self, tokenizer_id, example):
result = maybe_apply_chat_template(example, tokenizer)
# Checking if the result is a dictionary
- self.assertIsInstance(result, dict)
+ assert isinstance(result, dict)
# The chat template should be applied to the following keys
for key in ["prompt", "chosen", "rejected", "completion"]:
if key in example:
- self.assertIn(key, result)
- self.assertIsInstance(result[key], str)
+ assert key in result
+ assert isinstance(result[key], str)
# Exception for messages, the key is "text" once the chat template is applied
if "messages" in example:
- self.assertIn("text", result)
- self.assertIsInstance(result["text"], str)
+ assert "text" in result
+ assert isinstance(result["text"], str)
# The label should be kept
if "label" in example:
- self.assertIn("label", result)
- self.assertIsInstance(result["label"], bool)
- self.assertEqual(result["label"], example["label"])
+ assert "label" in result
+ assert isinstance(result["label"], bool)
+ assert result["label"] == example["label"]
def test_apply_chat_template_with_tools(self):
tokenizer = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")
@@ -420,13 +420,13 @@ def get_current_temperature(location: str):
result_with_tools = apply_chat_template(test_case, tokenizer, tools=[get_current_temperature])
# Verify tools are included in the output
- self.assertIn("get_current_temperature", result_with_tools["prompt"])
+ assert "get_current_temperature" in result_with_tools["prompt"]
# Test without tools
result_without_tools = apply_chat_template(test_case, tokenizer, tools=None)
# Verify tools are not included in the output
- self.assertNotIn("get_current_temperature", result_without_tools["prompt"])
+ assert "get_current_temperature" not in result_without_tools["prompt"]
class ApplyChatTemplateHarmonyTester(TrlTestCase):
@@ -459,7 +459,7 @@ def test_language_modeling(self):
<|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>""")
- self.assertEqual(output["text"], expected)
+ assert output["text"] == expected
def test_prompt_only(self):
messages = {
@@ -489,7 +489,7 @@ def test_prompt_only(self):
<|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""")
- self.assertEqual(output["prompt"], expected)
+ assert output["prompt"] == expected
def test_prompt_completion(self):
messages = {
@@ -523,8 +523,8 @@ def test_prompt_completion(self):
<|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""")
expected_completion = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>"
- self.assertEqual(output["prompt"], expected_prompt)
- self.assertEqual(output["completion"], expected_completion)
+ assert output["prompt"] == expected_prompt
+ assert output["completion"] == expected_completion
def test_preference(self):
messages = {
@@ -562,9 +562,9 @@ def test_preference(self):
expected_chosen = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>"
expected_rejected = "<|channel|>analysis<|message|>The user asks the color of the tree...<|end|><|start|>assistant<|channel|>final<|message|>It is green.<|return|>"
- self.assertEqual(output["prompt"], expected_prompt)
- self.assertEqual(output["chosen"], expected_chosen)
- self.assertEqual(output["rejected"], expected_rejected)
+ assert output["prompt"] == expected_prompt
+ assert output["chosen"] == expected_chosen
+ assert output["rejected"] == expected_rejected
def test_preference_with_implicit_prompt(self):
messages = {
@@ -614,8 +614,8 @@ def test_preference_with_implicit_prompt(self):
<|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the tree...<|end|><|start|>assistant<|channel|>final<|message|>It is green.<|return|>""")
- self.assertEqual(output["chosen"], expected_chosen)
- self.assertEqual(output["rejected"], expected_rejected)
+ assert output["chosen"] == expected_chosen
+ assert output["rejected"] == expected_rejected
def test_unpaired_preference(self):
messages = {
@@ -650,9 +650,9 @@ def test_unpaired_preference(self):
<|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""")
expected_completion = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>"
- self.assertEqual(output["prompt"], expected_prompt)
- self.assertEqual(output["completion"], expected_completion)
- self.assertTrue(output["label"])
+ assert output["prompt"] == expected_prompt
+ assert output["completion"] == expected_completion
+ assert output["label"]
class UnpairPreferenceDatasetTester(TrlTestCase):
@@ -675,58 +675,46 @@ class UnpairPreferenceDatasetTester(TrlTestCase):
def test_unpair_preference_dataset(self):
# Test that a paired dataset is correctly converted to unpaired
unpaired_dataset = unpair_preference_dataset(self.paired_dataset)
- self.assertEqual(
- unpaired_dataset.to_dict(),
- self.unpaired_dataset.to_dict(),
- "The paired dataset should be converted to unpaired.",
- )
+ assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), (
+ "The paired dataset should be converted to unpaired."
+ )
def test_unpair_preference_dataset_dict(self):
# Test that a paired dataset dict is correctly converted to unpaired
paired_dataset_dict = DatasetDict({"abc": self.paired_dataset})
unpaired_dataset_dict = unpair_preference_dataset(paired_dataset_dict)
- self.assertEqual(
- unpaired_dataset_dict["abc"].to_dict(),
- self.unpaired_dataset.to_dict(),
- "The paired dataset should be converted to unpaired.",
- )
+ assert unpaired_dataset_dict["abc"].to_dict() == \
+ self.unpaired_dataset.to_dict(), \
+ "The paired dataset should be converted to unpaired."
def test_maybe_unpair_preference_dataset(self):
# Test that a paired dataset is correctly converted to unpaired with maybe_unpair_preference_dataset
unpaired_dataset = maybe_unpair_preference_dataset(self.paired_dataset)
- self.assertEqual(
- unpaired_dataset.to_dict(),
- self.unpaired_dataset.to_dict(),
- "The paired dataset should be converted to unpaired.",
- )
+ assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), (
+ "The paired dataset should be converted to unpaired."
+ )
def test_maybe_unpair_preference_dataset_dict(self):
# Test that a paired dataset dict is correctly converted to unpaired with maybe_unpair_preference_dataset
paired_dataset_dict = DatasetDict({"abc": self.paired_dataset})
unpaired_dataset_dict = maybe_unpair_preference_dataset(paired_dataset_dict)
- self.assertEqual(
- unpaired_dataset_dict["abc"].to_dict(),
- self.unpaired_dataset.to_dict(),
- "The paired dataset should be converted to unpaired.",
- )
+ assert unpaired_dataset_dict["abc"].to_dict() == \
+ self.unpaired_dataset.to_dict(), \
+ "The paired dataset should be converted to unpaired."
def test_maybe_unpair_preference_dataset_already_unpaired(self):
# Test that an unpaired dataset remains unchanged with maybe_unpair_preference_dataset
unpaired_dataset = maybe_unpair_preference_dataset(self.unpaired_dataset)
- self.assertEqual(
- unpaired_dataset.to_dict(),
- self.unpaired_dataset.to_dict(),
- "The unpaired dataset should remain unchanged.",
- )
+ assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), (
+ "The unpaired dataset should remain unchanged."
+ )
def test_maybe_unpair_preference_dataset_dict_already_unpaired(self):
# Test that an unpaired dataset dict remains unchanged with maybe_unpair_preference_dataset
unpaired_dataset_dict = maybe_unpair_preference_dataset(DatasetDict({"abc": self.unpaired_dataset}))
- self.assertEqual(
- unpaired_dataset_dict["abc"].to_dict(),
- self.unpaired_dataset.to_dict(),
- "The unpaired dataset should remain unchanged.",
- )
+        assert unpaired_dataset_dict["abc"].to_dict() == self.unpaired_dataset.to_dict(), (
+            "The unpaired dataset should remain unchanged."
+        )
class ExtractPromptTester(TrlTestCase):
@@ -767,56 +755,44 @@ class ExtractPromptTester(TrlTestCase):
def test_extract_prompt_conversational(self):
# Test that the prompt is correctly extracted from the dataset
example_extracted_prompt = extract_prompt(self.example_implicit_prompt_conversational)
- self.assertEqual(
- example_extracted_prompt,
- self.example_explicit_prompt_conversational,
- "The prompt is not correctly extracted from the dataset.",
- )
+        assert example_extracted_prompt == self.example_explicit_prompt_conversational, (
+            "The prompt is not correctly extracted from the dataset."
+        )
def test_maybe_extract_prompt_conversational(self):
# Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_conversational)
- self.assertEqual(
- example_extracted_prompt,
- self.example_explicit_prompt_conversational,
- "The prompt is not correctly extracted from the dataset.",
- )
+        assert example_extracted_prompt == self.example_explicit_prompt_conversational, (
+            "The prompt is not correctly extracted from the dataset."
+        )
def test_maybe_extract_prompt_conversational_already_explicit(self):
# Test that the prompt remains unchanged with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_conversational)
- self.assertEqual(
- example_extracted_prompt,
- self.example_explicit_prompt_conversational,
- "The prompt should remain unchanged.",
- )
+        assert example_extracted_prompt == self.example_explicit_prompt_conversational, (
+            "The prompt should remain unchanged."
+        )
def test_extract_prompt_standard(self):
# Test that the prompt is correctly extracted from the dataset
example_extracted_prompt = extract_prompt(self.example_implicit_prompt_standard)
- self.assertEqual(
- example_extracted_prompt,
- self.example_explicit_prompt_standard,
- "The prompt is not correctly extracted from the dataset.",
- )
+        assert example_extracted_prompt == self.example_explicit_prompt_standard, (
+            "The prompt is not correctly extracted from the dataset."
+        )
def test_maybe_extract_prompt_standard(self):
# Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_standard)
- self.assertEqual(
- example_extracted_prompt,
- self.example_explicit_prompt_standard,
- "The prompt is not correctly extracted from the dataset.",
- )
+        assert example_extracted_prompt == self.example_explicit_prompt_standard, (
+            "The prompt is not correctly extracted from the dataset."
+        )
def test_maybe_extract_prompt_standard_already_explicit(self):
# Test that the prompt remains unchanged with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_standard)
- self.assertEqual(
- example_extracted_prompt,
- self.example_explicit_prompt_standard,
- "The prompt should remain unchanged.",
- )
+        assert example_extracted_prompt == self.example_explicit_prompt_standard, (
+            "The prompt should remain unchanged."
+        )
class TestPackDatasetWrapped(TrlTestCase):
@@ -832,7 +808,7 @@ def test_with_dataset(self):
"attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
- self.assertEqual(dataset.to_dict(), expected_output)
+ assert dataset.to_dict() == expected_output
def test_with_iterable_dataset(self):
examples = {
@@ -847,7 +823,7 @@ def test_with_iterable_dataset(self):
}
dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
num_examples = len(examples[next(iter(examples))])
- self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)
+ assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output
class TestPackDatasetBfd(TrlTestCase):
@@ -864,7 +840,7 @@ def test_simple(self):
"seq_lengths": [[4], [3, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd")
- self.assertEqual(dataset.to_dict(), expected_output)
+ assert dataset.to_dict() == expected_output
def test_with_iterable_dataset(self):
examples = {
@@ -880,7 +856,7 @@ def test_with_iterable_dataset(self):
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd")
num_examples = len(examples[next(iter(examples))])
- self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)
+ assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output
def test_with_truncation(self):
examples = {
@@ -895,7 +871,7 @@ def test_with_truncation(self):
"seq_lengths": [[4], [4], [2, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd")
- self.assertEqual(dataset.to_dict(), expected_output)
+ assert dataset.to_dict() == expected_output
def test_with_non_power_of_2(self):
examples = {
@@ -910,7 +886,7 @@ def test_with_non_power_of_2(self):
"seq_lengths": [[5], [4, 1], [3]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd")
- self.assertEqual(dataset.to_dict(), expected_output)
+ assert dataset.to_dict() == expected_output
class TestTruncateExamples(TrlTestCase):
@@ -926,7 +902,7 @@ def test_with_dataset(self):
"attention_mask": [[0, 1], [0, 0], [1]],
}
dataset = truncate_dataset(dataset, max_length)
- self.assertEqual(dataset.to_dict(), expected_output)
+ assert dataset.to_dict() == expected_output
def test_with_iterable_dataset(self):
examples = {
@@ -941,7 +917,7 @@ def test_with_iterable_dataset(self):
}
dataset = truncate_dataset(dataset, max_length)
num_examples = len(examples[next(iter(examples))])
- self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)
+ assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output
def test_with_extra_column(self):
examples = {
@@ -957,7 +933,7 @@ def test_with_extra_column(self):
"my_column": ["a", "b", "c"],
}
dataset = truncate_dataset(dataset, max_length)
- self.assertEqual(dataset.to_dict(), expected_output)
+ assert dataset.to_dict() == expected_output
class TestMaybeConvertToChatML(TrlTestCase):
@@ -975,7 +951,7 @@ def test_with_conversations_key(self):
{"role": "assistant", "content": "It is blue."},
]
}
- self.assertEqual(maybe_convert_to_chatml(example), expected_output)
+ assert maybe_convert_to_chatml(example) == expected_output
def test_without_conversations_key(self):
# Same as before, but we don't rename the keys
@@ -987,12 +963,12 @@ def test_without_conversations_key(self):
"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}],
}
- self.assertEqual(maybe_convert_to_chatml(example), expected_output)
+ assert maybe_convert_to_chatml(example) == expected_output
def test_not_conversional(self):
# When not needed, the example should remain unchanged
example = {"text": "The sky is blue."}
- self.assertEqual(maybe_convert_to_chatml(example), example)
+ assert maybe_convert_to_chatml(example) == example
def test_already_chatml(self):
# When the example is already in ChatML format, it should remain unchanged
@@ -1002,7 +978,7 @@ def test_already_chatml(self):
{"role": "assistant", "content": "It is blue."},
]
}
- self.assertEqual(maybe_convert_to_chatml(example), example)
+ assert maybe_convert_to_chatml(example) == example
# Run the tests
diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py
index c85845e34c3..cc132aedefb 100644
--- a/tests/test_dataset_formatting.py
+++ b/tests/test_dataset_formatting.py
@@ -44,20 +44,20 @@ def test_get_formatting_func_from_dataset_with_chatml_messages(self):
# Llama tokenizer
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
- self.assertIsInstance(formatting_func, Callable)
+ assert isinstance(formatting_func, Callable)
formatted_text = formatting_func(dataset[0])
expected = " [INST] You are helpful\n\nHello [/INST] Hi, how can I help you?"
- self.assertEqual(formatted_text, expected)
+ assert formatted_text == expected
formatted_text = formatting_func(dataset[0:1])
- self.assertListEqual(formatted_text, [expected])
+ assert formatted_text == [expected]
# ChatML tokenizer
formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer)
formatted_text = formatting_func(dataset[0])
expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
- self.assertEqual(formatted_text, expected)
+ assert formatted_text == expected
formatted_text = formatting_func(dataset[0:1])
- self.assertListEqual(formatted_text, [expected])
+ assert formatted_text == [expected]
def test_get_formatting_func_from_dataset_with_chatml_conversations(self):
dataset = Dataset.from_dict(
@@ -73,48 +73,48 @@ def test_get_formatting_func_from_dataset_with_chatml_conversations(self):
)
# Llama tokenizer
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
- self.assertIsInstance(formatting_func, Callable)
+ assert isinstance(formatting_func, Callable)
formatted_text = formatting_func(dataset[0])
expected = " [INST] You are helpful\n\nHello [/INST] Hi, how can I help you?"
- self.assertEqual(formatted_text, expected)
+ assert formatted_text == expected
formatted_text = formatting_func(dataset[0:1])
- self.assertListEqual(formatted_text, [expected])
+ assert formatted_text == [expected]
# ChatML tokenizer
formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer)
formatted_text = formatting_func(dataset[0])
expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
- self.assertEqual(formatted_text, expected)
+ assert formatted_text == expected
formatted_text = formatting_func(dataset[0:1])
- self.assertListEqual(formatted_text, [expected])
+ assert formatted_text == [expected]
def test_get_formatting_func_from_dataset_with_instruction(self):
dataset = Dataset.from_list(
[{"prompt": "What is 2+2?", "completion": "4"}, {"prompt": "What is 3+3?", "completion": "6"}]
)
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
- self.assertIsNotNone(formatting_func)
- self.assertIsInstance(formatting_func, Callable)
+ assert formatting_func is not None
+ assert isinstance(formatting_func, Callable)
formatted_text = formatting_func(dataset[0])
- self.assertEqual(formatted_text, " [INST] What is 2+2? [/INST] 4")
+ assert formatted_text == " [INST] What is 2+2? [/INST] 4"
formatted_text = formatting_func(dataset[0:1])
- self.assertListEqual(formatted_text, [" [INST] What is 2+2? [/INST] 4"])
+ assert formatted_text == [" [INST] What is 2+2? [/INST] 4"]
def test_get_formatting_func_from_dataset_from_hub(self):
ds_1 = load_dataset("philschmid/trl-test-instruction", split="train")
ds_2 = load_dataset("philschmid/dolly-15k-oai-style", split="train")
for ds in [ds_1, ds_2]:
formatting_func = get_formatting_func_from_dataset(ds, self.llama_tokenizer)
- self.assertIsNotNone(formatting_func)
- self.assertIsInstance(formatting_func, Callable)
+ assert formatting_func is not None
+ assert isinstance(formatting_func, Callable)
ds_3 = load_dataset("philschmid/guanaco-sharegpt-style", split="train")
formatting_func = get_formatting_func_from_dataset(ds_3, self.llama_tokenizer)
- self.assertIsNone(formatting_func)
+ assert formatting_func is None
def test_get_formatting_func_from_dataset_with_unknown_format(self):
dataset = Dataset.from_dict({"text": "test"})
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
- self.assertIsNone(formatting_func)
+ assert formatting_func is None
class SetupChatFormatTestCase(TrlTestCase):
@@ -132,13 +132,13 @@ def test_setup_chat_format(self):
_chatml = ChatMlSpecialTokens()
# Check if special tokens are correctly set
- self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>")
- self.assertEqual(modified_tokenizer.pad_token, "<|im_end|>")
- self.assertEqual(modified_tokenizer.bos_token, "<|im_start|>")
- self.assertEqual(modified_tokenizer.eos_token, _chatml.eos_token)
- self.assertEqual(modified_tokenizer.pad_token, _chatml.pad_token)
- self.assertEqual(modified_tokenizer.bos_token, _chatml.bos_token)
- self.assertEqual((modified_model.vocab_size % 123), 0)
+ assert modified_tokenizer.eos_token == "<|im_end|>"
+ assert modified_tokenizer.pad_token == "<|im_end|>"
+ assert modified_tokenizer.bos_token == "<|im_start|>"
+ assert modified_tokenizer.eos_token == _chatml.eos_token
+ assert modified_tokenizer.pad_token == _chatml.pad_token
+ assert modified_tokenizer.bos_token == _chatml.bos_token
+ assert (modified_model.vocab_size % 123) == 0
def test_example_with_setup_model(self):
modified_model, modified_tokenizer = setup_chat_format(
@@ -152,10 +152,8 @@ def test_example_with_setup_model(self):
]
prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False)
- self.assertEqual(
- prompt,
- "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n",
- )
+        expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
+        assert prompt == expected
class CloneChatTemplateTestCase(TrlTestCase):
@@ -168,7 +166,7 @@ def test_clone(self):
_, modified_tokenizer, _ = clone_chat_template(model, tokenizer, source)
# Check if special tokens are correctly set
- self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>")
+ assert modified_tokenizer.eos_token == "<|im_end|>"
def test_clone_with_resize(self):
# This tokenizer doesn't have a chat_template by default
@@ -181,9 +179,9 @@ def test_clone_with_resize(self):
)
# Check that the input embeddings have been resized to a multiple of 123
- self.assertEqual((modified_model.vocab_size % 123), 0)
+ assert (modified_model.vocab_size % 123) == 0
# Check that the input embeddings size matches the tokenizer vocabulary size
- self.assertEqual(model.vocab_size, len(modified_tokenizer.vocab))
+ assert model.vocab_size == len(modified_tokenizer.vocab)
def test_clone_with_resize_and_extra_tokens_already_in_vocab(self):
# This tokenizer doesn't have a chat_template by default
@@ -201,9 +199,9 @@ def test_clone_with_resize_and_extra_tokens_already_in_vocab(self):
)
# Check that the input embeddings have been resized to a multiple of 123
- self.assertEqual((modified_model.vocab_size % 124), 0)
+ assert (modified_model.vocab_size % 124) == 0
# Check that the input embeddings size matches the tokenizer vocabulary size
- self.assertEqual(model.vocab_size, len(modified_tokenizer.vocab))
+ assert model.vocab_size == len(modified_tokenizer.vocab)
def test_apply_new_chat_template(self):
# This tokenizer doesn't have a chat_template by default
@@ -219,10 +217,8 @@ def test_apply_new_chat_template(self):
]
prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False)
- self.assertEqual(
- prompt,
- "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n\n\n\n\nHi, how can I help you?<|im_end|>\n",
- )
+        expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n\n\n\n\nHi, how can I help you?<|im_end|>\n"
+        assert prompt == expected
def test_clone_with_sequence_classification_model(self):
# This tokenizer doesn't have a chat_template by default
@@ -235,4 +231,4 @@ def test_clone_with_sequence_classification_model(self):
_, modified_tokenizer, _ = clone_chat_template(model, tokenizer, source)
# Check if special tokens are correctly set
- self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>")
+ assert modified_tokenizer.eos_token == "<|im_end|>"
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index 6a4bfd22301..f429726c20b 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -40,6 +40,7 @@
from trl import DPOConfig, DPOTrainer, FDivergenceType
from .testing_utils import TrlTestCase, require_bitsandbytes, require_no_wandb
+import pytest
if is_vision_available():
@@ -84,14 +85,12 @@ def test_tokenize_row_no_truncation_no_special_tokens(self):
)
# Assert the correct output without truncation or special tokens
- self.assertEqual(
- result,
+ assert result == \
{
"prompt_input_ids": [464, 6766, 318],
"chosen_input_ids": [4171, 2], # eos_token added
"rejected_input_ids": [4077, 2], # eos_token added
- },
- )
+ }
def test_tokenize_row_with_truncation(self):
# Define the input features
@@ -107,14 +106,12 @@ def test_tokenize_row_with_truncation(self):
)
# Assert the correct output with truncation applied
- self.assertEqual(
- result,
+ assert result == \
{
"prompt_input_ids": [6766, 318], # truncated to the last 2 tokens
"chosen_input_ids": [4171], # truncated to 1 token
"rejected_input_ids": [4077], # truncated to 1 token
- },
- )
+ }
def test_tokenize_row_with_special_tokens(self):
# Define the input features
@@ -130,14 +127,12 @@ def test_tokenize_row_with_special_tokens(self):
)
# Assert the correct output with special tokens added
- self.assertEqual(
- result,
+ assert result == \
{
"prompt_input_ids": [0, 464, 6766, 318, 2], # bos_token and eos_token added
"chosen_input_ids": [4171, 2], # eos_token added
"rejected_input_ids": [4077, 2], # eos_token added
- },
- )
+ }
def test_tokenize_row_with_truncation_and_special_tokens(self):
# Define the input features
@@ -153,14 +148,12 @@ def test_tokenize_row_with_truncation_and_special_tokens(self):
)
# Assert the correct output with both truncation and special tokens
- self.assertEqual(
- result,
+ assert result == \
{
"prompt_input_ids": [464, 6766, 318, 2], # truncated to 4 tokens with bos_token and eos_token
"chosen_input_ids": [4171], # truncated to 1 token
"rejected_input_ids": [4077], # truncated to 1 token
- },
- )
+ }
class DPOTrainerTester(TrlTestCase):
@@ -193,13 +186,13 @@ def test_train(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+ assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
@parameterized.expand(
[
@@ -241,13 +234,13 @@ def test_train_loss_types(self, loss_type):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+ assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
@require_liger_kernel
def test_train_encoder_decoder_liger(self):
@@ -274,13 +267,13 @@ def test_train_encoder_decoder_liger(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+ assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
def test_dpo_trainer_with_weighting(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
@@ -304,13 +297,13 @@ def test_dpo_trainer_with_weighting(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+ assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
def test_train_with_multiple_loss_types(self):
"""
@@ -338,22 +331,22 @@ def test_train_with_multiple_loss_types(self):
# Test that training works
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Verify SFT loss is computed in the first test too
with torch.no_grad():
batch = next(iter(trainer.get_train_dataloader()))
loss, metrics = trainer.get_batch_loss_metrics(trainer.model, batch)
- self.assertIn("nll_loss", metrics) # SFT loss should be computed
+ assert "nll_loss" in metrics # SFT loss should be computed
def test_wrong_loss_weights_length(self):
- with self.assertRaises(ValueError) as context:
+ with pytest.raises(ValueError) as context:
DPOConfig(
output_dir=self.tmp_dir,
loss_type=["sigmoid", "bco_pair"],
loss_weights=[1.0, 0.5, 0.1], # Wrong length
)
- self.assertIn("Length of loss_weights list", str(context.exception))
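+        # pytest's ExceptionInfo stores the raised exception in `.value` (unittest's context used `.exception`)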
+        assert "Length of loss_weights list" in str(context.value)
@parameterized.expand([(None,), (0.5,)])
def test_dpo_trainer_without_providing_ref_model(self, rpo_alpha):
@@ -386,13 +379,13 @@ def test_dpo_trainer_without_providing_ref_model(self, rpo_alpha):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.equal(param, new_param))
+ assert not torch.equal(param, new_param)
def test_dpo_trainer_with_ref_model_is_model(self):
training_args = DPOConfig(
@@ -404,7 +397,7 @@ def test_dpo_trainer_with_ref_model_is_model(self):
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
DPOTrainer(
model=self.model,
ref_model=self.model, # ref_model can't be the same as model
@@ -437,13 +430,13 @@ def test_precompute_ref_batch_size(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+ assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
@require_peft
def test_dpo_trainer_without_providing_ref_model_with_lora(self):
@@ -486,14 +479,14 @@ def test_dpo_trainer_without_providing_ref_model_with_lora(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
if "lora" in n:
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.equal(param, new_param))
+ assert not torch.equal(param, new_param)
def test_dpo_trainer_w_dataset_num_proc(self):
training_args = DPOConfig(
@@ -555,13 +548,13 @@ def test_tr_dpo_trainer(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.ref_model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.equal(param, new_param))
+ assert not torch.equal(param, new_param)
@require_no_wandb
def test_dpo_trainer_generate_during_eval_no_wandb(self):
@@ -580,11 +573,8 @@ def test_dpo_trainer_generate_during_eval_no_wandb(self):
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
- with self.assertRaisesRegex(
- ValueError,
- expected_regex="`generate_during_eval=True` requires Weights and Biases, MLFlow or Comet to be installed."
- " Please install `wandb`, `mlflow` or `comet-ml` to resolve.",
- ):
+ with pytest.raises(ValueError, match="`generate_during_eval=True` requires Weights and Biases, MLFlow or Comet to be installed."
+ " Please install `wandb`, `mlflow` or `comet-ml` to resolve."):
DPOTrainer(
model=self.model,
ref_model=None,
@@ -645,7 +635,7 @@ def test_dpo_lora_save(self):
try:
AutoModelForCausalLM.from_pretrained(self.tmp_dir)
except OSError:
- self.fail("Loading the saved peft adapter failed")
+ pytest.fail("Loading the saved peft adapter failed")
@require_peft
@require_torch_gpu_if_bnb_not_multi_backend_enabled
@@ -826,7 +816,7 @@ def test_dpo_lora_tags(self):
)
for tag in ["dpo", "trl"]:
- self.assertIn(tag, trainer.model.model_tags)
+ assert tag in trainer.model.model_tags
@require_peft
def test_dpo_tags(self):
@@ -861,7 +851,7 @@ def test_dpo_tags(self):
)
for tag in ["dpo", "trl"]:
- self.assertIn(tag, trainer.model.model_tags)
+ assert tag in trainer.model.model_tags
@require_peft
def test_dpo_lora_force_use_ref(self):
@@ -895,7 +885,7 @@ def test_dpo_lora_force_use_ref(self):
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
# passing a peft_model as model and ref_model should error out,
# unless you pass `force_use_ref_model`
trainer = DPOTrainer(
@@ -953,8 +943,8 @@ def test_dpo_trainer_dtype(self):
args=training_args,
train_dataset=dummy_dataset["train"],
)
- self.assertEqual(trainer.model.config.dtype, torch.float16)
- self.assertEqual(trainer.ref_model.config.dtype, torch.float16)
+ assert trainer.model.config.dtype == torch.float16
+ assert trainer.ref_model.config.dtype == torch.float16
# Now test when `dtype` is provided but is wrong to either the model or the ref_model
training_args = DPOConfig(
@@ -965,7 +955,7 @@ def test_dpo_trainer_dtype(self):
report_to="none",
)
- with self.assertRaises(ValueError) as context:
+ with pytest.raises(ValueError) as context:
_ = DPOTrainer(
model=self.model_id,
processing_class=self.tokenizer,
@@ -973,11 +963,9 @@ def test_dpo_trainer_dtype(self):
train_dataset=dummy_dataset["train"],
)
- self.assertIn(
- "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid "
- "`torch.dtype` (e.g., 'float32'), but got -1.",
- str(context.exception),
- )
+        assert "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid " \
+            "`torch.dtype` (e.g., 'float32'), but got -1." in \
+            str(context.value)
training_args = DPOConfig(
output_dir=self.tmp_dir,
@@ -987,7 +975,7 @@ def test_dpo_trainer_dtype(self):
report_to="none",
)
- with self.assertRaises(ValueError) as context:
+ with pytest.raises(ValueError) as context:
_ = DPOTrainer(
model=self.model_id,
ref_model=self.model_id,
@@ -996,11 +984,9 @@ def test_dpo_trainer_dtype(self):
train_dataset=dummy_dataset["train"],
)
- self.assertIn(
- "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid "
- "`torch.dtype` (e.g., 'float32'), but got -1.",
- str(context.exception),
- )
+        assert "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid " \
+            "`torch.dtype` (e.g., 'float32'), but got -1." in \
+            str(context.value)
def test_dpo_loss_alpha_div_f(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
@@ -1041,7 +1027,7 @@ def test_dpo_loss_alpha_div_f(self):
losses, _, _ = trainer.dpo_loss(
policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
)
- self.assertTrue(torch.isfinite(losses).cpu().numpy().all())
+ assert torch.isfinite(losses).cpu().numpy().all()
def test_dpo_loss_js_div_f(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
@@ -1083,7 +1069,7 @@ def test_dpo_loss_js_div_f(self):
losses, _, _ = trainer.dpo_loss(
policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
)
- self.assertTrue(torch.isfinite(losses).cpu().numpy().all())
+ assert torch.isfinite(losses).cpu().numpy().all()
def test_dpo_trainer_use_logits_to_keep(self):
model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2"
@@ -1199,7 +1185,7 @@ def get_current_temperature(location: str):
# We don't run the training, but at this stage, the dataset is supposed to be pre-processed. When
# pre-processing, we expect the available tools to be explicitly mentioned in the system prompt. That's
# what we're checking here
- self.assertIn("get_current_temperature", tokenizer.decode(trainer.train_dataset["prompt_input_ids"][0]))
+ assert "get_current_temperature" in tokenizer.decode(trainer.train_dataset["prompt_input_ids"][0])
def test_padding_free(self):
model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2"
@@ -1235,7 +1221,7 @@ def test_padding_free(self):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+ assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
def test_compute_metrics(self):
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
@@ -1270,7 +1256,7 @@ def dummy_compute_metrics(*args, **kwargs):
trainer.train()
- self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)
+ assert trainer.state.log_history[-2]["eval_test"] == 0.0
def test_train_with_length_desensitization(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
@@ -1295,13 +1281,13 @@ def test_train_with_length_desensitization(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+ assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
@unittest.skipUnless(sys.version_info >= (3, 10), "Liger kernel is not supported on Python 3.9")
@parameterized.expand(
@@ -1359,20 +1345,20 @@ def test_dpo_trainer_with_liger(self, beta, loss_type):
train_output = trainer.train()
# Verify training completed successfully
- self.assertIsNotNone(train_output)
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert train_output is not None
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Verify loss is finite
- self.assertTrue(np.isfinite(trainer.state.log_history[-1]["train_loss"]))
+ assert np.isfinite(trainer.state.log_history[-1]["train_loss"])
# Check parameters have been updated
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# Only check non-zero parameters
if param.sum() != 0:
- self.assertFalse(torch.equal(param, new_param))
+ assert not torch.equal(param, new_param)
# Verify new parameters are finite
- self.assertTrue(torch.isfinite(new_param).all())
+ assert torch.isfinite(new_param).all()
# Verify model can still do forward pass after training
dummy_batch = next(iter(trainer.get_train_dataloader()))
@@ -1382,8 +1368,8 @@ def test_dpo_trainer_with_liger(self, beta, loss_type):
}
with torch.no_grad():
output = trainer.model(**model_inputs)
- self.assertIsNotNone(output)
- self.assertFalse("loss" in output.keys())
+ assert output is not None
+            assert "loss" not in output
def test_train_with_iterable_dataset(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
@@ -1411,13 +1397,13 @@ def test_train_with_iterable_dataset(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+ assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
@require_vision
@@ -1494,7 +1480,7 @@ def test_vdpo_trainer(self, model_id):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the trainable params have changed
for n, param in previous_trainable_params.items():
@@ -1510,7 +1496,7 @@ def test_vdpo_trainer(self, model_id):
# For some reason, these params are not updated. This is probably not related to TRL, but to
# the model itself. We should investigate this further, but for now we just skip these params.
continue
- self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")
+ assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
class TestDPOConfig(TrlTestCase):
diff --git a/tests/test_gkd_trainer.py b/tests/test_gkd_trainer.py
index 4a0d458440c..11ddcaedc81 100644
--- a/tests/test_gkd_trainer.py
+++ b/tests/test_gkd_trainer.py
@@ -70,10 +70,8 @@ def test_generate_on_policy_outputs_deterministic(self):
# Check if the generated texts start with the original prompts
for prompt, generated_text in zip(prompts, generated_texts):
- self.assertTrue(
- generated_text.startswith(prompt),
- f"Generated text '{generated_text}' does not start with prompt '{prompt}'",
- )
+            assert generated_text.startswith(prompt), (
+                f"Generated text '{generated_text}' does not start with prompt '{prompt}'"
+            )
# Run the generation twice and check if the outputs are identical
outputs2 = GKDTrainer.generate_on_policy_outputs(
@@ -83,15 +81,11 @@ def test_generate_on_policy_outputs_deterministic(self):
new_input_ids2, new_attention_mask2, new_labels2 = outputs2
# Check if the two generations are identical
- self.assertTrue(torch.all(new_input_ids.eq(new_input_ids2)), "Deterministic generations are not identical")
- self.assertTrue(
- torch.all(new_attention_mask.eq(new_attention_mask2)),
- "Attention masks for deterministic generations are not identical",
- )
- self.assertTrue(
- torch.all(new_labels.eq(new_labels2)),
- "Labels for deterministic generations are not identical",
- )
+ assert torch.all(new_input_ids.eq(new_input_ids2)), "Deterministic generations are not identical"
+        assert torch.all(new_attention_mask.eq(new_attention_mask2)), (
+            "Attention masks for deterministic generations are not identical"
+        )
+        assert torch.all(new_labels.eq(new_labels2)), (
+            "Labels for deterministic generations are not identical"
+        )
def test_generate_on_policy_outputs(self):
prompts = ["Hello, how are you?", "What's the weather like today?"]
@@ -107,25 +101,25 @@ def test_generate_on_policy_outputs(self):
)
# Check that outputs is a tuple of three tensors
- self.assertIsInstance(outputs, tuple)
- self.assertEqual(len(outputs), 3)
+ assert isinstance(outputs, tuple)
+ assert len(outputs) == 3
new_input_ids, new_attention_mask, new_labels = outputs
# Check shapes
batch_size = len(prompts)
- self.assertEqual(new_input_ids.shape[0], batch_size)
- self.assertEqual(new_attention_mask.shape[0], batch_size)
- self.assertEqual(new_labels.shape[0], batch_size)
+ assert new_input_ids.shape[0] == batch_size
+ assert new_attention_mask.shape[0] == batch_size
+ assert new_labels.shape[0] == batch_size
# Check types
- self.assertIsInstance(new_input_ids, torch.Tensor)
- self.assertIsInstance(new_attention_mask, torch.Tensor)
- self.assertIsInstance(new_labels, torch.Tensor)
+ assert isinstance(new_input_ids, torch.Tensor)
+ assert isinstance(new_attention_mask, torch.Tensor)
+ assert isinstance(new_labels, torch.Tensor)
# Check that new_input_ids and new_attention_mask have the same shape
- self.assertEqual(new_input_ids.shape, new_attention_mask.shape)
- self.assertEqual(new_labels.shape, new_attention_mask.shape)
+ assert new_input_ids.shape == new_attention_mask.shape
+ assert new_labels.shape == new_attention_mask.shape
class TestGeneralizedJSDLoss(TrlTestCase):
@@ -140,7 +134,7 @@ def setUp(self):
def test_uniform_distribution(self):
logits = torch.ones(1, 1, self.vocab_size)
loss = GKDTrainer.generalized_jsd_loss(logits, logits)
- self.assertAlmostEqual(loss.item(), 0, places=5)
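+        # pytest.approx with an absolute tolerance is used here in place of unittest's assertAlmostEqual(places=N)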
+        assert loss.item() == pytest.approx(0, abs=1e-5)
def test_generalized_jsd_loss_edge_cases(self):
# Setup
@@ -152,29 +146,29 @@ def test_generalized_jsd_loss_edge_cases(self):
expected_loss_beta_1 = F.kl_div(
F.log_softmax(teacher_logits, dim=-1), F.softmax(student_logits, dim=-1), reduction="batchmean"
)
- self.assertAlmostEqual(loss_beta_1.item(), expected_loss_beta_1.item(), places=5)
+        assert loss_beta_1.item() == pytest.approx(expected_loss_beta_1.item(), abs=1e-5)
# Case 2: beta = 0 (should be equivalent to KL(teacher || student))
loss_beta_0 = GKDTrainer.generalized_jsd_loss(student_logits, teacher_logits, beta=0)
expected_loss_beta_0 = F.kl_div(
F.log_softmax(student_logits, dim=-1), F.softmax(teacher_logits, dim=-1), reduction="batchmean"
)
- self.assertAlmostEqual(loss_beta_0.item(), expected_loss_beta_0.item(), places=5)
+        assert loss_beta_0.item() == pytest.approx(expected_loss_beta_0.item(), abs=1e-5)
def test_output_shape(self):
loss = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits)
- self.assertTrue(torch.is_tensor(loss))
- self.assertEqual(loss.shape, torch.Size([]))
+ assert torch.is_tensor(loss)
+ assert loss.shape == torch.Size([])
def test_beta_values(self):
loss_beta_0 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=0)
loss_beta_1 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=1)
- self.assertNotEqual(loss_beta_0, loss_beta_1)
+ assert loss_beta_0 != loss_beta_1
def test_temperature_scaling(self):
loss_temp_1 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, temperature=1)
loss_temp_2 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, temperature=2)
- self.assertNotEqual(loss_temp_1, loss_temp_2)
+ assert loss_temp_1 != loss_temp_2
def test_reduction_methods(self):
loss_batchmean = GKDTrainer.generalized_jsd_loss(
@@ -184,24 +178,24 @@ def test_reduction_methods(self):
loss_mean = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, reduction="mean")
loss_none = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, reduction="none")
- self.assertEqual(loss_batchmean.shape, torch.Size([]))
- self.assertEqual(loss_sum.shape, torch.Size([]))
- self.assertEqual(loss_mean.shape, torch.Size([]))
- self.assertEqual(loss_none.shape, self.student_logits.shape)
+ assert loss_batchmean.shape == torch.Size([])
+ assert loss_sum.shape == torch.Size([])
+ assert loss_mean.shape == torch.Size([])
+ assert loss_none.shape == self.student_logits.shape
def test_symmetry(self):
student_teacher = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=0.1)
teacher_student = GKDTrainer.generalized_jsd_loss(self.teacher_logits, self.student_logits, beta=0.1)
- self.assertNotEqual(student_teacher, teacher_student)
+ assert student_teacher != teacher_student
student_teacher = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=0.5)
teacher_student = GKDTrainer.generalized_jsd_loss(self.teacher_logits, self.student_logits, beta=0.5)
- self.assertEqual(student_teacher, teacher_student)
+ assert student_teacher == teacher_student
def test_zero_loss_for_identical_inputs(self):
identical_logits = torch.randn(self.batch_size, self.seq_length, self.vocab_size)
loss = GKDTrainer.generalized_jsd_loss(identical_logits, identical_logits)
- self.assertAlmostEqual(loss.item(), 0, places=6)
+        assert loss.item() == pytest.approx(0, abs=1e-6)
class GKDTrainerTester(TrlTestCase):
@@ -242,9 +236,9 @@ def test_gkd_trainer(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"])
- self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
- self.assertIn("model.safetensors", os.listdir(self.tmp_dir + "/checkpoint-2"))
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+ assert trainer.state.log_history[0]["eval_loss"] is not None
+ assert "model.safetensors" in os.listdir(self.tmp_dir + "/checkpoint-2")
@require_liger_kernel
@pytest.mark.xfail(reason="Computing the Liger loss spikes GPU memory usage, causing the test to run OOM.")
@@ -271,7 +265,7 @@ def test_gkd_trainer_with_liger(self):
trainer.train()
# Check we logged a train loss
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
def test_generation_config_init(self):
training_args = GKDConfig(output_dir=self.tmp_dir)
@@ -286,8 +280,8 @@ def test_generation_config_init(self):
processing_class=self.tokenizer,
)
- self.assertEqual(trainer.generation_config.pad_token_id, self.tokenizer.eos_token_id)
- self.assertEqual(trainer.generation_config.eos_token_id, self.model.generation_config.eos_token_id)
- self.assertEqual(trainer.generation_config.max_new_tokens, training_args.max_new_tokens)
- self.assertEqual(trainer.generation_config.temperature, training_args.temperature)
- self.assertEqual(trainer.generation_config.top_k, 0)
+ assert trainer.generation_config.pad_token_id == self.tokenizer.eos_token_id
+ assert trainer.generation_config.eos_token_id == self.model.generation_config.eos_token_id
+ assert trainer.generation_config.max_new_tokens == training_args.max_new_tokens
+ assert trainer.generation_config.temperature == training_args.temperature
+ assert trainer.generation_config.top_k == 0
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index a839a654bca..9a442ee6102 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -148,12 +148,12 @@ def test_training(self, config_name):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@parameterized.expand([("bnpo",), ("dr_grpo",), ("dapo",)])
def test_training_loss_types(self, loss_type):
@@ -180,12 +180,12 @@ def test_training_loss_types(self, loss_type):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_with_eval(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
@@ -233,12 +233,12 @@ def test_training_multiple_iterations(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_peft
def test_training_peft(self):
@@ -266,15 +266,15 @@ def test_training_peft(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model params to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed."
elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
@require_peft
def test_training_peft_with_gradient_checkpointing(self):
@@ -308,22 +308,22 @@ def test_training_peft_with_gradient_checkpointing(self):
)
# Verify gradient checkpointing is enabled
- self.assertIsInstance(trainer.model, PeftModel)
+ assert isinstance(trainer.model, PeftModel)
# Store initial parameters to check which ones change
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that only LoRA parameters have changed, base model parameters remain unchanged
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if "lora" in n.lower(): # LoRA parameters should change
- self.assertFalse(torch.equal(param, new_param), f"LoRA parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"LoRA parameter {n} has not changed."
else: # Base model parameters should not change
- self.assertTrue(torch.equal(param, new_param), f"Base parameter {n} has changed.")
+ assert torch.equal(param, new_param), f"Base parameter {n} has changed."
def test_training_different_reward_model(self):
# Use a reward model different from the model: different chat template, tokenization, etc.
@@ -357,12 +357,12 @@ def test_training_different_reward_model(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_reward_func_standard(self):
# Test if trainer can handle reward function with standard format
@@ -391,12 +391,12 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_reward_func_conversational(self):
# Test if trainer can handle reward function with conversational format
@@ -426,12 +426,12 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_multiple_reward_funcs(self):
# Test that GRPOTrainer can be instantiated with multiple reward functions
@@ -464,12 +464,12 @@ def reward_func2(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_multiple_reward_funcs_with_None_output(self):
"""Test that a valid math reward function is processed correctly while the code reward function returns None."""
@@ -508,12 +508,12 @@ def non_applicable_reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_multiple_reward_funcs_with_weights(self):
"""Test that GRPOTrainer can handle multiple reward functions with weights."""
@@ -548,16 +548,16 @@ def reward_func2(completions, **kwargs):
trainer.train()
# Check that training logs contain both reward metrics
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
- self.assertIn("rewards/reward_func1/mean", trainer.state.log_history[-1])
- self.assertIn("rewards/reward_func1/std", trainer.state.log_history[-1])
- self.assertIn("rewards/reward_func2/mean", trainer.state.log_history[-1])
- self.assertIn("rewards/reward_func2/std", trainer.state.log_history[-1])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
+ assert "rewards/reward_func1/mean" in trainer.state.log_history[-1]
+ assert "rewards/reward_func1/std" in trainer.state.log_history[-1]
+ assert "rewards/reward_func2/mean" in trainer.state.log_history[-1]
+ assert "rewards/reward_func2/std" in trainer.state.log_history[-1]
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_multiple_mixed_reward_funcs(self):
# Test if the trainer can handle a mix of reward functions and reward models
@@ -586,12 +586,12 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_reward_func_additional_column(self):
# Test if trainer can handle reward function that rely on additional columns in the dataset
@@ -624,12 +624,12 @@ def reward_func(completions, some_values, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_with_sync_ref_model(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -655,12 +655,12 @@ def test_training_with_sync_ref_model(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_beta_non_zero(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -684,12 +684,12 @@ def test_training_beta_non_zero(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_with_entropy_filter(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -713,12 +713,12 @@ def test_training_with_entropy_filter(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@unittest.skip("We should add a mock for the vLLM server.")
@require_peft
@@ -755,16 +755,16 @@ def test_training_vllm_and_peft(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model params to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed."
elif "base_layer" not in n and "original_module" not in n:
# We expect the peft params to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
@require_vllm
@unittest.skip("We should add a mock for the vLLM server.")
@@ -793,12 +793,12 @@ def test_training_vllm_guided_decoding(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_vllm
@unittest.skip("We should add a mock for the vLLM server.")
@@ -828,12 +828,12 @@ def test_training_vllm_importance_sampling_correction(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_with_additional_generation_kwargs(self):
"""Test that training works with additional generation kwargs."""
@@ -863,12 +863,12 @@ def test_training_with_additional_generation_kwargs(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_vllm
@unittest.skip("We should add a mock for the vLLM server.")
@@ -901,12 +901,12 @@ def test_training_vllm_with_additional_generation_kwargs(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@parameterized.expand([(False,), ("group",), ("batch",), (True,), ("none",)])
def test_training_scale_rewards(self, scale_rewards):
@@ -932,12 +932,12 @@ def test_training_scale_rewards(self, scale_rewards):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@patch("transformers.generation.utils.GenerationMixin.generate")
def test_training_with_mask_truncated_completions(self, mock_generate):
@@ -982,12 +982,12 @@ def fake_generate(input_ids, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_with_mask_truncated_completions_all_masked(self):
"""
@@ -1020,12 +1020,12 @@ def test_training_with_mask_truncated_completions_all_masked(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertTrue(torch.equal(param, new_param), f"Parameter {n} has changed.")
+ assert torch.equal(param, new_param), f"Parameter {n} has changed."
def test_warning_raised_all_rewards_none(self):
"""Test that a proper warning is raised when all rewards are None."""
@@ -1054,7 +1054,7 @@ def always_none_reward_func(completions, **kwargs):
trainer.train()
expected_warning = "All reward functions returned None for the following kwargs:"
- self.assertIn(expected_warning, cm.output[0])
+ assert expected_warning in cm.output[0]
def test_training_num_generations_larger_than_batch_size(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -1079,12 +1079,12 @@ def test_training_num_generations_larger_than_batch_size(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_delta_clipping(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -1109,12 +1109,12 @@ def test_training_delta_clipping(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_multiple_dataloader_workers(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -1139,12 +1139,12 @@ def test_training_multiple_dataloader_workers(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_with_generation_kwargs(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -1169,12 +1169,12 @@ def test_training_with_generation_kwargs(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_with_reward_func_accessing_trainer_state(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -1246,7 +1246,7 @@ def test_prepare_input_called_with_correct_data(self):
with patch.object(GRPOTrainer, "training_step", wraps=trainer.training_step) as mock_prepare:
trainer.train()
# 3 epochs * 2 iterations * 2 generation batches to cover the dataset * 4 steps_per_generation
- self.assertEqual(mock_prepare.call_count, 48)
+ assert mock_prepare.call_count == 48
for i in range(0, 8): # Generation batch repeated 8 times (steps_per_generation*num_iterations)
assert mock_prepare.call_args_list[i].args[1] == expected_first_generation_batch
for i in range(8, 16):
@@ -1289,7 +1289,7 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
# Because of the way the tiny models are initialized, the gradient does not flow properly through the
@@ -1305,7 +1305,7 @@ def reward_func(completions, **kwargs):
if n.startswith(params_to_skip):
continue
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_vision
def test_training_vlm_beta_non_zero(self):
@@ -1335,7 +1335,7 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
# Because of the way the tiny models are initialized, the gradient does not flow properly through the
@@ -1345,7 +1345,7 @@ def reward_func(completions, **kwargs):
if n.startswith(params_to_skip):
continue
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_vision
@require_peft
@@ -1380,15 +1380,15 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model params to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed."
elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
@require_vision
def test_training_vlm_and_importance_sampling(self):
@@ -1418,7 +1418,7 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
# Because of the way the tiny models are initialized, the gradient does not flow properly through the
@@ -1428,7 +1428,7 @@ def reward_func(completions, **kwargs):
if n.startswith(params_to_skip):
continue
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_vision
@require_liger_kernel
@@ -1460,7 +1460,7 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
# Because of the way the tiny models are initialized, the gradient does not flow properly through the
@@ -1470,7 +1470,7 @@ def reward_func(completions, **kwargs):
if n.startswith(params_to_skip):
continue
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_vision
def test_training_vlm_and_prompt_truncation(self):
@@ -1501,7 +1501,7 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
# Because of the way the tiny models are initialized, the gradient does not flow properly through the
@@ -1511,7 +1511,7 @@ def reward_func(completions, **kwargs):
if n.startswith(params_to_skip):
continue
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_vision
@require_vllm
@@ -1551,11 +1551,11 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_vision
def test_training_vlm_multi_image(self):
@@ -1588,14 +1588,14 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
# Because of the way the tiny models are initialized, the gradient does not flow properly through the
# vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_sequence_importance_sampling(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -1621,12 +1621,12 @@ def test_training_sequence_importance_sampling(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_mismatched_reward_processing_classes_length(self):
"""Test that mismatched length between reward_funcs and reward_processing_classes raises error."""
@@ -1645,7 +1645,7 @@ def test_mismatched_reward_processing_classes_length(self):
training_args = GRPOConfig(output_dir=self.tmp_dir, report_to="none")
- with self.assertRaises(ValueError) as context:
+ with pytest.raises(ValueError) as context:
GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs=reward_models,
@@ -1654,7 +1654,7 @@ def test_mismatched_reward_processing_classes_length(self):
train_dataset=dataset,
)
- self.assertIn("must match", str(context.exception))
+ assert "must match" in str(context.exception)
def test_correct_reward_processing_classes_list(self):
"""Test that correct list of reward_processing_classes works properly."""
@@ -1685,7 +1685,7 @@ def test_correct_reward_processing_classes_list(self):
train_dataset=dataset,
)
- self.assertEqual(len(trainer.reward_processing_classes), len(reward_models))
+ assert len(trainer.reward_processing_classes) == len(reward_models)
def test_single_reward_model_with_single_processing_class(self):
"""Test that single reward model with single processing class works."""
@@ -1709,8 +1709,8 @@ def test_single_reward_model_with_single_processing_class(self):
train_dataset=dataset,
)
- self.assertEqual(len(trainer.reward_processing_classes), 1)
- self.assertEqual(trainer.reward_processing_classes[0], single_processing_class)
+ assert len(trainer.reward_processing_classes) == 1
+ assert trainer.reward_processing_classes[0] == single_processing_class
@pytest.mark.low_priority
@@ -1731,12 +1731,12 @@ def test_add(self):
self.replay_buffer.add(scores, data)
# Check if the buffer contains the correct number of elements
- self.assertEqual(len(self.replay_buffer.heap), 5)
+ assert len(self.replay_buffer.heap) == 5
# Check if the buffer maintains the min-heap property
heap_scores = [item[0] for item in self.replay_buffer.heap]
- self.assertEqual(heap_scores[0], min(heap_scores))
- self.assertEqual(heap_scores[0], 0.3)
+ assert heap_scores[0] == min(heap_scores)
+ assert heap_scores[0] == 0.3
def test_add_more_than_maxlen(self):
# Add elements to the replay buffer
@@ -1753,12 +1753,12 @@ def test_add_more_than_maxlen(self):
self.replay_buffer.add(scores, data)
# Check if the buffer contains the correct number of elements
- self.assertEqual(len(self.replay_buffer.heap), 5)
+ assert len(self.replay_buffer.heap) == 5
# Check if the buffer maintains the min-heap property
heap_scores = [item[0] for item in self.replay_buffer.heap]
- self.assertEqual(heap_scores[0], min(heap_scores))
- self.assertEqual(heap_scores[0], 0.5) # 0.3 and 0.4 should be removed
+ assert heap_scores[0] == min(heap_scores)
+ assert heap_scores[0] == 0.5 # 0.3 and 0.4 should be removed
def test_sample(self):
# Add elements to the replay buffer
@@ -1776,9 +1776,9 @@ def test_sample(self):
sampled = self.replay_buffer.sample(num_samples=3)
# Check if the sampled elements are from the buffer
- self.assertEqual(len(sampled), 3)
+ assert len(sampled) == 3
for item in sampled:
- self.assertIn(item, [entry[1] for entry in self.replay_buffer.heap])
+ assert item in [entry[1] for entry in self.replay_buffer.heap]
@pytest.mark.low_priority
@@ -1841,12 +1841,12 @@ def test_update_with_replay_buffer_no_variance(self):
outputs = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
- self.assertIsNotNone(outputs)
- self.assertIn("pixel_values", outputs)
- self.assertIn("old_per_token_logps", outputs)
- self.assertEqual(len(self.trainer.replay_buffer.heap), 2)
+ assert outputs is not None
+ assert "pixel_values" in outputs
+ assert "old_per_token_logps" in outputs
+ assert len(self.trainer.replay_buffer.heap) == 2
for pid in outputs["prompt_ids"]:
- self.assertNotIn(pid.tolist(), original_prompt_ids.tolist())
+ assert pid.tolist() not in original_prompt_ids.tolist()
def test_update_with_replay_buffer_with_variance(self):
self._prepopulate_buffer()
@@ -1855,8 +1855,8 @@ def test_update_with_replay_buffer_with_variance(self):
sampled = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
- self.assertEqual(len(self.trainer.replay_buffer.heap), 4) # grew
- self.assertIsNone(sampled)
+ assert len(self.trainer.replay_buffer.heap) == 4 # grew
+ assert sampled is None
def test_update_with_mixed_variance(self):
self._prepopulate_buffer()
@@ -1866,16 +1866,16 @@ def test_update_with_mixed_variance(self):
outputs = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
- self.assertEqual(len(self.trainer.replay_buffer.heap), 3) # grew by 1
+ assert len(self.trainer.replay_buffer.heap) == 3 # grew by 1
output_prompt_ids = outputs["prompt_ids"].view(-1, self.trainer.num_generations, 2).tolist()
buffer_ids = [item[1]["prompt_ids"].tolist() for item in self.trainer.replay_buffer.heap]
found_from_buffer = any(pid in buffer_ids for pid in output_prompt_ids)
found_from_original = any(pid in original_prompt_ids for pid in output_prompt_ids)
- self.assertTrue(found_from_buffer)
- self.assertTrue(found_from_original)
- self.assertNotIn([[1, 2], [3, 4]], output_prompt_ids) # excluded no-variance group
+ assert found_from_buffer
+ assert found_from_original
+ assert [[1, 2], [3, 4]] not in output_prompt_ids # excluded no-variance group
def test_update_with_inputs_different_seq_len(self):
"""
@@ -1910,8 +1910,8 @@ def test_update_with_inputs_different_seq_len(self):
outputs_after_sampling = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
# Seq length of current batch should be preserved
- self.assertEqual(outputs_after_sampling["prompt_ids"].shape[-1], 3)
- self.assertEqual(len(self.trainer.replay_buffer.heap), 3)
+ assert outputs_after_sampling["prompt_ids"].shape[-1] == 3
+ assert len(self.trainer.replay_buffer.heap) == 3
output_prompt_ids = outputs_after_sampling["prompt_ids"].view(-1, self.trainer.num_generations, 3).tolist()
buffered_prompt_completion_ids = [
@@ -1921,24 +1921,20 @@ def test_update_with_inputs_different_seq_len(self):
buffered_prompt_ids, buffered_completion_ids = zip(*buffered_prompt_completion_ids)
# Check for new entry with seq len 3 in buffer
- self.assertIn([[3, 4, 5], [3, 4, 5]], buffered_prompt_ids) # excluded no-variance group
- self.assertIn(
- [[1013, 1014, pad_token_id], [1015, 1016, 1017]], buffered_completion_ids
- ) # excluded no-variance group
+ assert [[3, 4, 5], [3, 4, 5]] in buffered_prompt_ids # excluded no-variance group
+ assert [[1013, 1014, pad_token_id], [1015, 1016, 1017]] in buffered_completion_ids # excluded no-variance group
# Check that sampled outputs contain one group with prompt_ids starting with a pad token
- self.assertTrue(
- [
+ assert [
[pad_token_id, 101, 102],
[pad_token_id, 102, 103],
- ]
- in output_prompt_ids
+ ] \
+ in output_prompt_ids \
or [
[pad_token_id, 104, 105],
[pad_token_id, 106, 107],
- ]
+ ] \
in output_prompt_ids
- )
@pytest.mark.low_priority
@@ -1973,12 +1969,12 @@ def custom_reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
class GSPOTokenTrainerTester(TrlTestCase):
@@ -2006,12 +2002,12 @@ def test_training(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
if __name__ == "__main__":
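The hunks above apply the same mechanical mapping from unittest helpers to bare asserts throughout test_grpo_trainer.py. As a quick reference, here is a minimal self-contained sketch of that mapping, using hypothetical tensors rather than anything from the test suite:

import torch


def test_assertion_mapping_sketch():
    # Hypothetical stand-ins for previous/current trainable parameters.
    before = torch.zeros(2, 2)
    after = torch.ones(2, 2)
    log_history = [{"train_loss": 0.42}]

    assert log_history[-1]["train_loss"] is not None          # was: assertIsNotNone(...)
    assert "train_loss" in log_history[-1]                     # was: assertIn(...)
    assert not torch.equal(before, after), "param unchanged"   # was: assertFalse(cond, msg)
    assert torch.allclose(after, after)                        # was: assertTrue(cond)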
diff --git a/tests/test_judges.py b/tests/test_judges.py
index cce8f961a5f..9238a7f5cb2 100644
--- a/tests/test_judges.py
+++ b/tests/test_judges.py
@@ -35,17 +35,17 @@ def test_all_true_judge(self):
judge = AllTrueJudge(judges=[RandomBinaryJudge(), RandomBinaryJudge()])
prompts, completions = self._get_prompts_and_single_completions()
judgements = judge.judge(prompts=prompts, completions=completions)
- self.assertEqual(len(judgements), 2)
- self.assertTrue(all(judgement in {0, 1, -1} for judgement in judgements))
+ assert len(judgements) == 2
+ assert all(judgement in {0, 1, -1} for judgement in judgements)
@unittest.skip("This test needs to be run manually since it requires a valid Hugging Face API key.")
def test_hugging_face_judge(self):
judge = HfPairwiseJudge()
prompts, completions = self._get_prompts_and_pairwise_completions()
ranks = judge.judge(prompts=prompts, completions=completions)
- self.assertEqual(len(ranks), 2)
- self.assertTrue(all(isinstance(rank, int) for rank in ranks))
- self.assertEqual(ranks, [0, 1])
+ assert len(ranks) == 2
+ assert all(isinstance(rank, int) for rank in ranks)
+ assert ranks == [0, 1]
def load_pair_rm_judge(self):
# When using concurrent tests, PairRM may fail to load the model while another job is still downloading.
@@ -62,15 +62,15 @@ def test_pair_rm_judge(self):
judge = self.load_pair_rm_judge()
prompts, completions = self._get_prompts_and_pairwise_completions()
ranks = judge.judge(prompts=prompts, completions=completions)
- self.assertEqual(len(ranks), 2)
- self.assertTrue(all(isinstance(rank, int) for rank in ranks))
- self.assertEqual(ranks, [0, 1])
+ assert len(ranks) == 2
+ assert all(isinstance(rank, int) for rank in ranks)
+ assert ranks == [0, 1]
@require_llm_blender
def test_pair_rm_judge_return_scores(self):
judge = self.load_pair_rm_judge()
prompts, completions = self._get_prompts_and_pairwise_completions()
probs = judge.judge(prompts=prompts, completions=completions, return_scores=True)
- self.assertEqual(len(probs), 2)
- self.assertTrue(all(isinstance(prob, float) for prob in probs))
- self.assertTrue(all(0 <= prob <= 1 for prob in probs))
+ assert len(probs) == 2
+ assert all(isinstance(prob, float) for prob in probs)
+ assert all(0 <= prob <= 1 for prob in probs)
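A hedged aside on the assert all(...) conversions in this file: pytest's assertion rewriting only reports the per-item results when it is handed a concrete list; with a generator expression the failure message just shows the exhausted generator object. Both forms pass or fail identically, so this is purely about readability of failures:

def test_all_reporting_sketch():
    ranks = [0, 1]
    assert all(isinstance(rank, int) for rank in ranks)    # terse report on failure
    assert all([isinstance(rank, int) for rank in ranks])  # failure report shows each result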
diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py
index 21b425fec05..ac68a00d9a7 100644
--- a/tests/test_kto_trainer.py
+++ b/tests/test_kto_trainer.py
@@ -23,6 +23,7 @@
from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize
from .testing_utils import TrlTestCase, require_no_wandb
+import pytest
class KTOTrainerTester(TrlTestCase):
@@ -91,13 +92,13 @@ def test_kto_trainer(self, name, config_name, loss_type, pre_compute, eval_datas
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.equal(param, new_param))
+ assert not torch.equal(param, new_param)
def test_kto_trainer_with_ref_model_is_model(self):
training_args = KTOConfig(
@@ -109,7 +110,7 @@ def test_kto_trainer_with_ref_model_is_model(self):
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
KTOTrainer(
model=self.model,
ref_model=self.model, # ref_model can't be the same as model
@@ -149,13 +150,13 @@ def test_tokenize_and_process_tokens(self):
batched=True,
batch_size=2,
)
- self.assertListEqual(tokenized_dataset["prompt"][:], train_dataset["prompt"][:])
- self.assertListEqual(tokenized_dataset["completion"][:], train_dataset["completion"][:])
- self.assertListEqual(tokenized_dataset["label"][:], train_dataset["label"][:])
- self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
- self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
- self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13])
- self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1])
+ assert tokenized_dataset["prompt"][:] == train_dataset["prompt"][:]
+ assert tokenized_dataset["completion"][:] == train_dataset["completion"][:]
+ assert tokenized_dataset["label"][:] == train_dataset["label"][:]
+ assert tokenized_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091]
+ assert tokenized_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1]
+ assert tokenized_dataset["answer_input_ids"][0] == [27261, 13]
+ assert tokenized_dataset["answer_attention_mask"][0] == [1, 1]
# Test corruption of (prompt, completion) pairs for KL dataset
for batch_size in [2, 3]:
@@ -166,18 +167,12 @@ def test_tokenize_and_process_tokens(self):
# the last batch remains unaltered. This is a rare scenario that does not impact the training
# process, so we exclude it from testing by iterating only up to len - 1.
for i in range(len(tokenized_kl_dataset["answer_input_ids"]) - 1):
- self.assertListEqual(
- tokenized_dataset["prompt_input_ids"][i],
- tokenized_kl_dataset["prompt_input_ids"][i],
- )
- self.assertListEqual(
- tokenized_dataset["prompt_attention_mask"][i],
- tokenized_kl_dataset["prompt_attention_mask"][i],
- )
- self.assertNotEqual(
- tokenized_dataset["answer_input_ids"][i],
- tokenized_kl_dataset["answer_input_ids"][i],
- )
+ assert tokenized_dataset["prompt_input_ids"][i] == \
+ tokenized_kl_dataset["prompt_input_ids"][i]
+ assert tokenized_dataset["prompt_attention_mask"][i] == \
+ tokenized_kl_dataset["prompt_attention_mask"][i]
+ assert tokenized_dataset["answer_input_ids"][i] != \
+ tokenized_kl_dataset["answer_input_ids"][i]
fn_kwargs = {
"prefix": "",
@@ -189,14 +184,14 @@ def test_tokenize_and_process_tokens(self):
"max_prompt_length": trainer.max_prompt_length,
}
processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs, num_proc=2)
- self.assertListEqual(processed_dataset["prompt"][:], train_dataset["prompt"][:])
- self.assertListEqual(processed_dataset["completion"][:], train_dataset["completion"][:])
- self.assertListEqual(processed_dataset["label"][:], train_dataset["label"][:])
- self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
- self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
- self.assertListEqual(processed_dataset["completion_input_ids"][0], [46518, 374, 2664, 1091, 27261, 13, 151645])
- self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1])
- self.assertListEqual(processed_dataset["completion_labels"][0], [-100, -100, -100, -100, 27261, 13, 151645])
+ assert processed_dataset["prompt"][:] == train_dataset["prompt"][:]
+ assert processed_dataset["completion"][:] == train_dataset["completion"][:]
+ assert processed_dataset["label"][:] == train_dataset["label"][:]
+ assert processed_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091]
+ assert processed_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1]
+ assert processed_dataset["completion_input_ids"][0] == [46518, 374, 2664, 1091, 27261, 13, 151645]
+ assert processed_dataset["completion_attention_mask"][0] == [1, 1, 1, 1, 1, 1, 1]
+ assert processed_dataset["completion_labels"][0] == [-100, -100, -100, -100, 27261, 13, 151645]
def test_kto_trainer_without_providing_ref_model(self):
training_args = KTOConfig(
@@ -226,13 +221,13 @@ def test_kto_trainer_without_providing_ref_model(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.equal(param, new_param))
+ assert not torch.equal(param, new_param)
@require_peft
def test_kto_trainer_without_providing_ref_model_with_lora(self):
@@ -274,14 +269,14 @@ def test_kto_trainer_without_providing_ref_model_with_lora(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
if "lora" in n:
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.equal(param, new_param))
+ assert not torch.equal(param, new_param)
@require_no_wandb
def test_kto_trainer_generate_during_eval_no_wandb(self):
@@ -300,11 +295,8 @@ def test_kto_trainer_generate_during_eval_no_wandb(self):
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")
- with self.assertRaisesRegex(
- ValueError,
- expected_regex="`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
- " Please install `wandb` or `comet-ml` to resolve.",
- ):
+ with pytest.raises(ValueError, match="`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
+ " Please install `wandb` or `comet-ml` to resolve."):
KTOTrainer(
model=self.model,
ref_model=None,
@@ -365,7 +357,7 @@ def test_kto_lora_save(self):
try:
AutoModelForCausalLM.from_pretrained(self.tmp_dir)
except OSError:
- self.fail("Loading the saved peft adapter failed")
+ pytest.fail("Loading the saved peft adapter failed")
@require_liger_kernel
def test_kto_trainer_with_liger(self):
@@ -389,14 +381,14 @@ def test_kto_trainer_with_liger(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
- self.assertFalse(torch.equal(param, new_param))
+ assert not torch.equal(param, new_param)
def test_compute_metrics(self):
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
@@ -432,4 +424,4 @@ def dummy_compute_metrics(*args, **kwargs):
trainer.train()
- self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)
+ assert trainer.state.log_history[-2]["eval_test"] == 0.0
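The assertRaisesRegex conversion above relies on pytest.raises(..., match=...) applying the pattern with re.search against the exception message, just as assertRaisesRegex did. The messages used here happen to match even without escaping, but literal text containing regex metacharacters can be wrapped in re.escape. A small sketch with a made-up error message:

import re

import pytest


def test_raises_match_is_a_regex():
    with pytest.raises(ValueError, match=re.escape("bad value (expected `int`).")):
        raise ValueError("bad value (expected `int`). Please fix the config.")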
diff --git a/tests/test_modeling_geometric_mixture_wrapper.py b/tests/test_modeling_geometric_mixture_wrapper.py
index ae6f5010821..65553b79b77 100644
--- a/tests/test_modeling_geometric_mixture_wrapper.py
+++ b/tests/test_modeling_geometric_mixture_wrapper.py
@@ -40,9 +40,9 @@ def test_forward(self):
output = self.wrapper(input_ids=input_ids, attention_mask=attention_mask)
- self.assertIsNotNone(output)
- self.assertTrue(hasattr(output, "logits"))
- self.assertEqual(output.logits.shape, (1, 5, self.model.config.vocab_size))
+ assert output is not None
+ assert hasattr(output, "logits")
+ assert output.logits.shape == (1, 5, self.model.config.vocab_size)
def test_mixture_coefficient(self):
input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=self.device)
@@ -57,7 +57,7 @@ def test_mixture_coefficient(self):
self.mixture_coef * ref_model_output.logits + (1 - self.mixture_coef) * model_output.logits, dim=-1
)
- self.assertTrue(torch.allclose(wrapper_output.logits, expected_logits, atol=1e-5))
+ assert torch.allclose(wrapper_output.logits, expected_logits, atol=1e-5)
def test_prepare_inputs_for_generation(self):
input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=self.device)
@@ -65,6 +65,6 @@ def test_prepare_inputs_for_generation(self):
inputs = self.wrapper.prepare_inputs_for_generation(input_ids, attention_mask=attention_mask, use_cache=True)
- self.assertIn("input_ids", inputs)
- self.assertIn("attention_mask", inputs)
- self.assertFalse(inputs.get("use_cache", False))
+ assert "input_ids" in inputs
+ assert "attention_mask" in inputs
+ assert not inputs.get("use_cache", False)
diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py
index b0a75211175..355dcb2a1c8 100644
--- a/tests/test_modeling_value_head.py
+++ b/tests/test_modeling_value_head.py
@@ -65,7 +65,7 @@ def test_value_head(self):
"""
for model_name in self.all_model_names:
model = self.trl_model_class.from_pretrained(model_name)
- self.assertTrue(hasattr(model, "v_head"))
+ assert hasattr(model, "v_head")
def test_value_head_shape(self):
r"""
@@ -73,7 +73,7 @@ def test_value_head_shape(self):
"""
for model_name in self.all_model_names:
model = self.trl_model_class.from_pretrained(model_name)
- self.assertEqual(model.v_head.summary.weight.shape[0], 1)
+ assert model.v_head.summary.weight.shape[0] == 1
def test_value_head_init_random(self):
r"""
@@ -82,9 +82,7 @@ def test_value_head_init_random(self):
"""
for model_name in self.all_model_names:
model = self.trl_model_class.from_pretrained(model_name)
- self.assertFalse(
- torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias))
- )
+ assert not torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias))
def test_value_head_not_str(self):
r"""
@@ -94,7 +92,7 @@ def test_value_head_not_str(self):
for model_name in self.all_model_names:
pretrained_model = self.transformers_model_class.from_pretrained(model_name)
model = self.trl_model_class.from_pretrained(pretrained_model)
- self.assertTrue(hasattr(model, "v_head"))
+ assert hasattr(model, "v_head")
def test_from_save_trl(self):
"""
@@ -110,7 +108,7 @@ def test_from_save_trl(self):
# Check if the weights are the same
for key in model_from_save.state_dict():
- self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key]))
+ assert torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])
def test_from_save_trl_sharded(self):
"""
@@ -125,7 +123,7 @@ def test_from_save_trl_sharded(self):
# Check if the weights are the same
for key in model_from_save.state_dict():
- self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key]))
+ assert torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])
def test_from_save_transformers_sharded(self):
"""
@@ -143,11 +141,9 @@ def test_from_save_transformers_sharded(self):
# Check if the weights are the same
for key in transformers_model.state_dict():
- self.assertTrue(
- torch.allclose(
+ assert torch.allclose(
transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
)
- )
def test_from_save_transformers(self):
"""
@@ -166,27 +162,21 @@ def test_from_save_transformers(self):
# Check if the weights are the same
for key in transformers_model.state_dict():
- self.assertTrue(
- torch.allclose(
+ assert torch.allclose(
transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
)
- )
# Check if the trl model has the same keys as the transformers model
# except the v_head
for key in trl_model.state_dict():
if "v_head" not in key:
- self.assertIn(key, transformers_model.state_dict())
+ assert key in transformers_model.state_dict()
# check if the weights are the same
- self.assertTrue(
- torch.allclose(trl_model.state_dict()[key], transformers_model.state_dict()[key])
- )
+ assert torch.allclose(trl_model.state_dict()[key], transformers_model.state_dict()[key])
# check if they have the same modules
- self.assertEqual(
- set(transformers_model_from_save.state_dict().keys()),
- set(transformers_model.state_dict().keys()),
- )
+ assert set(transformers_model_from_save.state_dict().keys()) == \
+ set(transformers_model.state_dict().keys())
class CausalLMValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase):
@@ -217,7 +207,7 @@ def test_inference(self):
# Check if the outputs are of the right size - here
# we always output 3 values - logits, loss, and value states
- self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE)
+ assert len(outputs) == EXPECTED_OUTPUT_SIZE
def test_dropout_config(self):
r"""
@@ -229,7 +219,7 @@ def test_dropout_config(self):
model = self.trl_model_class.from_pretrained(pretrained_model)
# Check if v head of the model has the same dropout as the config
- self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob)
+ assert model.v_head.dropout.p == pretrained_model.config.summary_dropout_prob
def test_dropout_kwargs(self):
r"""
@@ -241,12 +231,12 @@ def test_dropout_kwargs(self):
model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs)
# Check if v head of the model has the same dropout as the config
- self.assertEqual(model.v_head.dropout.p, 0.5)
+ assert model.v_head.dropout.p == 0.5
model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5)
# Check if v head of the model has the same dropout as the config
- self.assertEqual(model.v_head.dropout.p, 0.5)
+ assert model.v_head.dropout.p == 0.5
@parameterized.expand(ALL_CAUSAL_LM_MODELS)
def test_generate(self, model_name):
@@ -271,14 +261,12 @@ def test_transformers_bf16_kwargs(self):
lm_head_namings = ["lm_head", "embed_out", "output_layer"]
- self.assertTrue(
- any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings),
- "Can't test the model because it doesn't have any of the expected lm_head namings",
- )
+ assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings), \
+ "Can't test the model because it doesn't have any of the expected lm_head namings"
for lm_head_naming in lm_head_namings:
if hasattr(trl_model.pretrained_model, lm_head_naming):
- self.assertEqual(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype, torch.bfloat16)
+ assert getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16
dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(self.device)
@@ -296,13 +284,11 @@ def test_push_to_hub(self):
model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(model_name + "-ppo")
# check all keys
- self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys())
+ assert model.state_dict().keys() == model_from_pretrained.state_dict().keys()
for name, param in model.state_dict().items():
- self.assertTrue(
- torch.allclose(param, model_from_pretrained.state_dict()[name]),
- f"Parameter {name} is not the same after push_to_hub and from_pretrained",
- )
+ assert torch.allclose(param, model_from_pretrained.state_dict()[name]), \
+ f"Parameter {name} is not the same after push_to_hub and from_pretrained"
class Seq2SeqValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase):
@@ -334,7 +320,7 @@ def test_inference(self):
# Check if the outputs are of the right size - here
# we always output 3 values - logits, loss, and value states
- self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE)
+ assert len(outputs) == EXPECTED_OUTPUT_SIZE
def test_dropout_config(self):
r"""
@@ -346,7 +332,7 @@ def test_dropout_config(self):
model = self.trl_model_class.from_pretrained(pretrained_model)
# Check if v head of the model has the same dropout as the config
- self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob)
+ assert model.v_head.dropout.p == pretrained_model.config.summary_dropout_prob
def test_dropout_kwargs(self):
r"""
@@ -358,12 +344,12 @@ def test_dropout_kwargs(self):
model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs)
# Check if v head of the model has the same dropout as the config
- self.assertEqual(model.v_head.dropout.p, 0.5)
+ assert model.v_head.dropout.p == 0.5
model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5)
# Check if v head of the model has the same dropout as the config
- self.assertEqual(model.v_head.dropout.p, 0.5)
+ assert model.v_head.dropout.p == 0.5
@parameterized.expand(ALL_SEQ2SEQ_MODELS)
def test_generate(self, model_name):
@@ -389,13 +375,11 @@ def test_push_to_hub(self):
model_from_pretrained = self.trl_model_class.from_pretrained(model_name + "-ppo")
# check all keys
- self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys())
+ assert model.state_dict().keys() == model_from_pretrained.state_dict().keys()
for name, param in model.state_dict().items():
- self.assertTrue(
- torch.allclose(param, model_from_pretrained.state_dict()[name]),
- f"Parameter {name} is not the same after push_to_hub and from_pretrained",
- )
+ assert torch.allclose(param, model_from_pretrained.state_dict()[name]), \
+ f"Parameter {name} is not the same after push_to_hub and from_pretrained"
def test_transformers_bf16_kwargs(self):
r"""
@@ -408,13 +392,11 @@ def test_transformers_bf16_kwargs(self):
lm_head_namings = self.trl_model_class.lm_head_namings
- self.assertTrue(
- any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
- )
+ assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
for lm_head_naming in lm_head_namings:
if hasattr(trl_model.pretrained_model, lm_head_naming):
- self.assertTrue(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16)
+ assert getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16
dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(self.device)
@@ -453,16 +435,16 @@ def test_independent_reference(self):
last_ref_layer_after = ref_model.get_parameter(layer_1).data.clone()
# before optimization ref and model are identical
- self.assertTrue((first_layer_before == first_ref_layer_before).all())
- self.assertTrue((last_layer_before == last_ref_layer_before).all())
+ assert (first_layer_before == first_ref_layer_before).all()
+ assert (last_layer_before == last_ref_layer_before).all()
# ref model stays identical after optimization
- self.assertTrue((first_ref_layer_before == first_ref_layer_after).all())
- self.assertTrue((last_ref_layer_before == last_ref_layer_after).all())
+ assert (first_ref_layer_before == first_ref_layer_after).all()
+ assert (last_ref_layer_before == last_ref_layer_after).all()
# optimized model changes
- self.assertFalse((first_layer_before == first_layer_after).all())
- self.assertFalse((last_layer_before == last_layer_after).all())
+ assert not (first_layer_before == first_layer_after).all()
+ assert not (last_layer_before == last_layer_after).all()
def test_shared_layers(self):
layer_0 = self.layer_format.format(layer=0)
@@ -487,15 +469,15 @@ def test_shared_layers(self):
second_ref_layer_after = ref_model.get_parameter(layer_1).data.clone()
# before optimization ref and model are identical
- self.assertTrue((first_layer_before == first_ref_layer_before).all())
- self.assertTrue((second_layer_before == second_ref_layer_before).all())
+ assert (first_layer_before == first_ref_layer_before).all()
+ assert (second_layer_before == second_ref_layer_before).all()
# ref model stays identical after optimization
- self.assertTrue((first_ref_layer_before == first_ref_layer_after).all())
- self.assertTrue((second_ref_layer_before == second_ref_layer_after).all())
+ assert (first_ref_layer_before == first_ref_layer_after).all()
+ assert (second_ref_layer_before == second_ref_layer_after).all()
# first layer of optimized model stays the same
- self.assertTrue((first_layer_before == first_layer_after).all())
+ assert (first_layer_before == first_layer_after).all()
# other layers in optimized model change
- self.assertFalse((second_layer_before == second_layer_after).all())
+ assert not (second_layer_before == second_layer_after).all()
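Several hunks in this file attach a message to the converted assert (assert cond, msg) and split long conditions with a trailing backslash. A hedged reminder of the one real pitfall with that form, using hypothetical tensors:

import torch


def test_assert_message_forms():
    param = torch.ones(3)
    restored = torch.ones(3)

    # Fine: condition and message separated by a comma, backslash continuation.
    assert torch.allclose(param, restored), \
        "Parameter is not the same after push_to_hub and from_pretrained"

    # Pitfall: wrapping condition AND message in one pair of parentheses builds a
    # non-empty tuple, which is always truthy, so the assert can never fail:
    # assert (torch.allclose(param, restored),
    #         "Parameter is not the same after push_to_hub and from_pretrained")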
diff --git a/tests/test_nash_md_trainer.py b/tests/test_nash_md_trainer.py
index 4550c35e1d1..25caaff3d38 100644
--- a/tests/test_nash_md_trainer.py
+++ b/tests/test_nash_md_trainer.py
@@ -65,7 +65,7 @@ def test_nash_md_trainer_training(self, config_name):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
@require_peft
def test_training_with_peft(self):
@@ -93,7 +93,7 @@ def test_training_with_peft(self):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
@require_peft
def test_training_with_peft_and_ref_model(self):
@@ -122,7 +122,7 @@ def test_training_with_peft_and_ref_model(self):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
@require_peft
def test_training_with_peft_model_and_peft_config(self):
@@ -153,7 +153,7 @@ def test_training_with_peft_model_and_peft_config(self):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
@require_peft
def test_training_pre_pefted_model_implicit_ref_with_reward_model(self):
@@ -184,7 +184,7 @@ def test_training_pre_pefted_model_implicit_ref_with_reward_model(self):
trainer.train()
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
@parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
@require_llm_blender
@@ -215,4 +215,4 @@ def test_nash_md_trainer_judge_training(self, config_name):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py
index 336db8a089f..b8eb83a0f04 100644
--- a/tests/test_online_dpo_trainer.py
+++ b/tests/test_online_dpo_trainer.py
@@ -73,7 +73,7 @@ def test_training(self, config_name):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
def test_training_model_str(self):
training_args = OnlineDPOConfig(
@@ -98,7 +98,7 @@ def test_training_model_str(self):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
def test_training_with_ref_model(self):
training_args = OnlineDPOConfig(
@@ -124,7 +124,7 @@ def test_training_with_ref_model(self):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
def test_ref_model_is_model(self):
training_args = OnlineDPOConfig(
@@ -136,7 +136,7 @@ def test_ref_model_is_model(self):
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
OnlineDPOTrainer(
model=self.model,
ref_model=self.model, # ref_model can't be the same as model
@@ -174,7 +174,7 @@ def test_training_with_peft(self):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
@require_peft
def test_training_with_peft_and_ref_model(self):
@@ -204,7 +204,7 @@ def test_training_with_peft_and_ref_model(self):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
@require_peft
def test_training_with_peft_model_and_peft_config(self):
@@ -236,7 +236,7 @@ def test_training_with_peft_model_and_peft_config(self):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
@require_llm_blender
@parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
@@ -262,7 +262,7 @@ def test_training_with_judge(self, config_name):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
@parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
@require_torch_accelerator
@@ -293,7 +293,7 @@ def test_training_with_vllm(self, config_name):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
@require_vllm
def test_training_with_vllm_colocate(self):
@@ -330,57 +330,57 @@ def test_training_with_vllm_colocate(self):
)
# Verify vLLM setup
- self.assertTrue(trainer.use_vllm)
- self.assertEqual(trainer.vllm_mode, "colocate")
- self.assertIsNotNone(trainer.llm)
+ assert trainer.use_vllm
+ assert trainer.vllm_mode == "colocate"
+ assert trainer.llm is not None
# self.assertIsNone(trainer.vllm_client)
# self.assertEqual(trainer.vllm_gpu_memory_utilization, 0.2)
# Verify generation parameters
- self.assertEqual(trainer.temperature, 0.9)
- self.assertEqual(trainer.top_p, 0.95)
- self.assertEqual(trainer.top_k, 50)
- self.assertEqual(trainer.repetition_penalty, 1.1)
+ assert trainer.temperature == 0.9
+ assert trainer.top_p == 0.95
+ assert trainer.top_k == 50
+ assert trainer.repetition_penalty == 1.1
# Verify generation config
- self.assertIsNotNone(trainer.generation_config)
- self.assertEqual(trainer.generation_config.temperature, 0.9)
- self.assertEqual(trainer.generation_config.top_p, 0.95)
- self.assertEqual(trainer.generation_config.top_k, 50)
- self.assertEqual(trainer.generation_config.repetition_penalty, 1.1)
- self.assertEqual(trainer.generation_config.max_tokens, 32)
+ assert trainer.generation_config is not None
+ assert trainer.generation_config.temperature == 0.9
+ assert trainer.generation_config.top_p == 0.95
+ assert trainer.generation_config.top_k == 50
+ assert trainer.generation_config.repetition_penalty == 1.1
+ assert trainer.generation_config.max_tokens == 32
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
def test_vllm_config_validation(self):
"""Test vLLM configuration validation"""
# Test valid vllm_mode values
config = OnlineDPOConfig(use_vllm=True, vllm_mode="server")
- self.assertEqual(config.vllm_mode, "server")
+ assert config.vllm_mode == "server"
config = OnlineDPOConfig(use_vllm=True, vllm_mode="colocate")
- self.assertEqual(config.vllm_mode, "colocate")
+ assert config.vllm_mode == "colocate"
# Test default values
config = OnlineDPOConfig()
- self.assertEqual(config.vllm_mode, "server")
- self.assertIsNone(config.vllm_server_base_url)
- self.assertEqual(config.vllm_server_host, "0.0.0.0")
- self.assertEqual(config.vllm_server_port, 8000)
- self.assertEqual(config.vllm_server_timeout, 240.0)
- self.assertEqual(config.vllm_gpu_memory_utilization, 0.55)
+ assert config.vllm_mode == "server"
+ assert config.vllm_server_base_url is None
+ assert config.vllm_server_host == "0.0.0.0"
+ assert config.vllm_server_port == 8000
+ assert config.vllm_server_timeout == 240.0
+ assert config.vllm_gpu_memory_utilization == 0.55
# Test generation parameters
- self.assertEqual(config.top_p, 1.0)
- self.assertIsNone(config.top_k)
- self.assertIsNone(config.min_p)
- self.assertEqual(config.repetition_penalty, 1.0)
- self.assertFalse(config.use_transformers_paged)
- self.assertIsNone(config.cache_implementation)
- self.assertIsNone(config.generation_kwargs)
+ assert config.top_p == 1.0
+ assert config.top_k is None
+ assert config.min_p is None
+ assert config.repetition_penalty == 1.0
+ assert not config.use_transformers_paged
+ assert config.cache_implementation is None
+ assert config.generation_kwargs is None
def test_generation_config_setup(self):
"""Test that generation configuration is properly set up for both vLLM and transformers"""
@@ -407,17 +407,17 @@ def test_generation_config_setup(self):
)
# Verify transformers generation config
- self.assertFalse(trainer.use_vllm)
+ assert not trainer.use_vllm
# When not using vLLM, these attributes should not be set
- self.assertFalse(hasattr(trainer, "llm") and trainer.llm is not None)
- self.assertFalse(hasattr(trainer, "vllm_client") and trainer.vllm_client is not None)
- self.assertIsNotNone(trainer.generation_config)
- self.assertEqual(trainer.generation_config.temperature, 0.8)
- self.assertEqual(trainer.generation_config.top_p, 0.9)
- self.assertEqual(trainer.generation_config.top_k, 40)
- self.assertEqual(trainer.generation_config.repetition_penalty, 1.2)
- self.assertEqual(trainer.generation_config.max_new_tokens, 64)
- self.assertFalse(trainer.generation_config.do_sample) # From generation_kwargs
+ assert not (hasattr(trainer, "llm") and trainer.llm is not None)
+ assert not (hasattr(trainer, "vllm_client") and trainer.vllm_client is not None)
+ assert trainer.generation_config is not None
+ assert trainer.generation_config.temperature == 0.8
+ assert trainer.generation_config.top_p == 0.9
+ assert trainer.generation_config.top_k == 40
+ assert trainer.generation_config.repetition_penalty == 1.2
+ assert trainer.generation_config.max_new_tokens == 64
+ assert not trainer.generation_config.do_sample # From generation_kwargs
@require_torch_accelerator
@parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
@@ -447,7 +447,7 @@ def test_training_with_transformers_paged(self, config_name):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
@parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
def test_training_with_reward_funcs(self, config_name):
@@ -475,11 +475,11 @@ def simple_reward_func(prompts, completions, completion_ids, **kwargs):
)
trainer.train()
- self.assertIn("train_loss", trainer.state.log_history[-1])
- self.assertEqual(len(trainer.reward_funcs), 2)
- self.assertIsNotNone(trainer.reward_weights)
- self.assertAlmostEqual(trainer.reward_weights[0].item(), 0.7, places=5)
- self.assertAlmostEqual(trainer.reward_weights[1].item(), 0.3, places=5)
+ assert "train_loss" in trainer.state.log_history[-1]
+ assert len(trainer.reward_funcs) == 2
+ assert trainer.reward_weights is not None
+ assert trainer.reward_weights[0].item() == pytest.approx(0.7, abs=1e-5)
+ assert trainer.reward_weights[1].item() == pytest.approx(0.3, abs=1e-5)
@require_vision
@@ -531,4 +531,4 @@ def test_online_dpo_vlm_trainer(self, model_id):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
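One note on the approximate-equality conversion in this file: unittest's assertAlmostEqual(a, b, places=5) checks round(a - b, 5) == 0, which is roughly a 5e-6 absolute window, while pytest.approx states the tolerance explicitly. A minimal sketch showing that both forms accept the same kind of near-miss:

import pytest


def test_almost_equal_sketch():
    weight = 0.7000004  # hypothetical reward weight read back from the trainer

    # Literal translation of assertAlmostEqual(weight, 0.7, places=5)
    assert round(abs(weight - 0.7), 5) == 0

    # pytest idiom with an explicit absolute tolerance
    assert weight == pytest.approx(0.7, abs=1e-5)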
diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py
index 5898ac8d7dd..2f444eb4dfc 100644
--- a/tests/test_orpo_trainer.py
+++ b/tests/test_orpo_trainer.py
@@ -82,13 +82,13 @@ def test_orpo_trainer(self, name, config_name):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.equal(param, new_param))
+ assert not torch.equal(param, new_param)
@parameterized.expand(
[
@@ -137,14 +137,14 @@ def test_orpo_trainer_with_lora(self, config_name):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
if "lora" in n:
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.equal(param, new_param))
+ assert not torch.equal(param, new_param)
def test_compute_metrics(self):
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
@@ -178,4 +178,4 @@ def dummy_compute_metrics(*args, **kwargs):
trainer.train()
- self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)
+ assert trainer.state.log_history[-2]["eval_test"] == 0.0
diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py
index 0543ee31c3c..3b68be8b817 100644
--- a/tests/test_peft_models.py
+++ b/tests/test_peft_models.py
@@ -63,7 +63,7 @@ def test_peft_requires_grad(self):
model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model)
# Check that the value head has requires_grad=True
- self.assertTrue(model.v_head.summary.weight.requires_grad)
+ assert model.v_head.summary.weight.requires_grad
def test_check_peft_model_nb_trainable_params(self):
r"""
@@ -76,12 +76,12 @@ def test_check_peft_model_nb_trainable_params(self):
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
- self.assertEqual(nb_trainable_params, 905)
+ assert nb_trainable_params == 905
# Check that the number of trainable param for the non-peft model is correct
non_peft_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id)
nb_trainable_params = sum(p.numel() for p in non_peft_model.parameters() if p.requires_grad)
- self.assertEqual(nb_trainable_params, 2428641)
+ assert nb_trainable_params == 2428641
def test_create_peft_model_from_config(self):
r"""
@@ -92,13 +92,13 @@ def test_create_peft_model_from_config(self):
)
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
- self.assertEqual(nb_trainable_params, 905)
+ assert nb_trainable_params == 905
causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id)
trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config)
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
- self.assertEqual(nb_trainable_params, 905)
+ assert nb_trainable_params == 905
@require_torch_gpu_if_bnb_not_multi_backend_enabled
def test_create_bnb_peft_model_from_config(self):
@@ -112,8 +112,8 @@ def test_create_bnb_peft_model_from_config(self):
)
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
- self.assertEqual(nb_trainable_params, 905)
- self.assertIsInstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt)
+ assert nb_trainable_params == 905
+ assert isinstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt)
causal_lm_model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id, load_in_8bit=True, device_map="auto"
@@ -121,8 +121,8 @@ def test_create_bnb_peft_model_from_config(self):
trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config)
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
- self.assertEqual(nb_trainable_params, 905)
- self.assertIsInstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt)
+ assert nb_trainable_params == 905
+ assert isinstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt)
def test_save_pretrained_peft(self):
r"""
@@ -136,31 +136,23 @@ def test_save_pretrained_peft(self):
model.save_pretrained(self.tmp_dir)
# check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory
- self.assertTrue(
- os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"),
- f"{self.tmp_dir}/adapter_model.safetensors does not exist",
- )
- self.assertTrue(
- os.path.exists(f"{self.tmp_dir}/adapter_config.json"), f"{self.tmp_dir}/adapter_config.json does not exist"
- )
+ assert os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), \
+ f"{self.tmp_dir}/adapter_model.safetensors does not exist"
+ assert os.path.exists(f"{self.tmp_dir}/adapter_config.json"), f"{self.tmp_dir}/adapter_config.json does not exist"
# check also for `pytorch_model.bin` and make sure it only contains `v_head` weights
- self.assertTrue(
- os.path.exists(f"{self.tmp_dir}/pytorch_model.bin"), f"{self.tmp_dir}/pytorch_model.bin does not exist"
- )
+ assert os.path.exists(f"{self.tmp_dir}/pytorch_model.bin"), f"{self.tmp_dir}/pytorch_model.bin does not exist"
# check that only keys that starts with `v_head` are in the dict
maybe_v_head = torch.load(f"{self.tmp_dir}/pytorch_model.bin", weights_only=True)
- self.assertTrue(
- all(k.startswith("v_head") for k in maybe_v_head.keys()),
- f"keys in {self.tmp_dir}/pytorch_model.bin do not start with `v_head`",
- )
+ assert all(k.startswith("v_head") for k in maybe_v_head.keys()), \
+ f"keys in {self.tmp_dir}/pytorch_model.bin do not start with `v_head`"
model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir)
# check all the weights are the same
for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()):
- self.assertTrue(torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}")
+ assert torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}"
def test_load_pretrained_peft(self):
r"""
@@ -175,18 +167,14 @@ def test_load_pretrained_peft(self):
model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir)
# check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory
- self.assertTrue(
- os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"),
- f"{self.tmp_dir}/adapter_model.safetensors does not exist",
- )
- self.assertTrue(
- os.path.exists(f"{self.tmp_dir}/adapter_config.json"), f"{self.tmp_dir}/adapter_config.json does not exist"
- )
+ assert os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), \
+ f"{self.tmp_dir}/adapter_model.safetensors does not exist"
+ assert os.path.exists(f"{self.tmp_dir}/adapter_config.json"), f"{self.tmp_dir}/adapter_config.json does not exist"
# check all the weights are the same
for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()):
if p1[0] not in ["v_head.summary.weight", "v_head.summary.bias"]:
- self.assertTrue(torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}")
+ assert torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}"
def test_continue_training_peft_model(self):
r"""
@@ -200,4 +188,4 @@ def test_continue_training_peft_model(self):
model = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir, is_trainable=True)
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
- self.assertEqual(nb_trainable_params, 905)
+ assert nb_trainable_params == 905
diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py
index f8b95a8e5ff..13354c36b29 100644
--- a/tests/test_ppo_trainer.py
+++ b/tests/test_ppo_trainer.py
@@ -107,8 +107,8 @@ def test_basic_training(self):
policy_weights_updated = True
break
- self.assertTrue(critic_weights_updated, "Critic weights were not updated during training")
- self.assertTrue(policy_weights_updated, "Policy weights were not updated during training")
+ assert critic_weights_updated, "Critic weights were not updated during training"
+ assert policy_weights_updated, "Policy weights were not updated during training"
@require_peft
def test_peft_training(self):
@@ -171,5 +171,5 @@ def test_peft_training(self):
policy_weights_updated = True
break
- self.assertTrue(critic_weights_updated, "Critic weights were not updated during training")
- self.assertTrue(policy_weights_updated, "Policy LoRA weights were not updated during training")
+ assert critic_weights_updated, "Critic weights were not updated during training"
+ assert policy_weights_updated, "Policy LoRA weights were not updated during training"
diff --git a/tests/test_prm_trainer.py b/tests/test_prm_trainer.py
index e26428c5203..76398519836 100644
--- a/tests/test_prm_trainer.py
+++ b/tests/test_prm_trainer.py
@@ -75,13 +75,11 @@ def test_tokenize_row_no_truncation(self):
is_eval=False,
)
- self.assertEqual(
- result,
+ assert result == \
{
"input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030],
"labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0],
- },
- )
+ }
def test_tokenize_row_train_on_last_step_only(self):
# Define the input features
@@ -102,13 +100,11 @@ def test_tokenize_row_train_on_last_step_only(self):
is_eval=False,
)
- self.assertEqual(
- result,
+ assert result == \
{
"input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030],
"labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0],
- },
- )
+ }
def test_tokenize_row_prompt_truncation(self):
# Define the input features
@@ -130,13 +126,11 @@ def test_tokenize_row_prompt_truncation(self):
is_eval=False,
)
- self.assertEqual(
- result,
+ assert result == \
{
"input_ids": [6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030],
"labels": [-100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0],
- },
- )
+ }
def test_tokenize_row_completion_truncation(self):
# Define the input features
@@ -158,13 +152,11 @@ def test_tokenize_row_completion_truncation(self):
is_eval=False,
)
- self.assertEqual(
- result,
+ assert result == \
{
"input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11],
"labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100],
- },
- )
+ }
def test_tokenize_row_prompt_completion_truncation(self):
# Define the input features
@@ -186,13 +178,11 @@ def test_tokenize_row_prompt_completion_truncation(self):
is_eval=False,
)
- self.assertEqual(
- result,
+ assert result == \
{
"input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030],
"labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1],
- },
- )
+ }
def test_tokenize_row_multi_token_separator(self):
# Define the input features
@@ -214,13 +204,11 @@ def test_tokenize_row_multi_token_separator(self):
is_eval=False,
)
- self.assertEqual(
- result,
+ assert result == \
{
"input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 1030, 4995, 11, 22, 1030, 1030],
"labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, 0],
- },
- )
+ }
class PRMTrainerTester(TrlTestCase):
@@ -244,12 +232,12 @@ def test_train_full(self, train_on_last_step_only):
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+ assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
def test_train_full_pretokenized(self):
dummy_dataset = Dataset.from_dict(
@@ -297,12 +285,12 @@ def test_train_full_pretokenized(self):
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
- self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+ assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
@require_peft
def test_train_lora(self):
@@ -337,17 +325,17 @@ def test_train_lora(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
+ assert not torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)
# Check that the non trainable parameters have not changed
for n, param in previous_non_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
+ assert torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)
def test_tags(self):
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train")
@@ -355,4 +343,4 @@ def test_tags(self):
trainer = PRMTrainer(
model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
)
- self.assertEqual(trainer.model.model_tags, trainer._tag_names)
+ assert trainer.model.model_tags == trainer._tag_names
diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py
index b4d53e16941..ba8300c78ff 100644
--- a/tests/test_reward_trainer.py
+++ b/tests/test_reward_trainer.py
@@ -131,12 +131,12 @@ def test_train(self, model_id):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
@parameterized.expand(
[
@@ -165,12 +165,12 @@ def test_train_dataset_types(self, config_name):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_model(self):
# Instantiate the model
@@ -192,12 +192,12 @@ def test_train_model(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_from_causal_lm(self):
# Get the dataset
@@ -216,12 +216,12 @@ def test_train_from_causal_lm(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_model_dtype(self):
# Get the dataset
@@ -247,7 +247,7 @@ def test_train_model_dtype(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
@@ -257,8 +257,8 @@ def test_train_model_dtype(self):
continue
new_param = trainer.model.get_parameter(n)
# Check the torch dtype
- self.assertEqual(new_param.dtype, torch.float16)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert new_param.dtype == torch.float16
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
@require_peft
def test_train_dense_with_peft_config(self):
@@ -287,15 +287,15 @@ def test_train_dense_with_peft_config(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed"
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
@require_peft
def test_train_moe_with_peft_config(self):
@@ -324,15 +324,15 @@ def test_train_moe_with_peft_config(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed"
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
@require_peft
def test_train_peft_model(self):
@@ -361,15 +361,15 @@ def test_train_peft_model(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed"
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
@require_peft
def test_train_dense_with_peft_config_and_gradient_checkpointing(self):
@@ -398,15 +398,15 @@ def test_train_dense_with_peft_config_and_gradient_checkpointing(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed"
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
@require_peft
def test_train_moe_with_peft_config_and_gradient_checkpointing(self):
@@ -435,15 +435,15 @@ def test_train_moe_with_peft_config_and_gradient_checkpointing(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed"
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
@require_peft
def test_train_with_peft_model_and_gradient_checkpointing(self):
@@ -462,7 +462,7 @@ def test_train_with_peft_model_and_gradient_checkpointing(self):
trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset)
# Verify model is a PeftModel
- self.assertIsInstance(trainer.model, PeftModel)
+ assert isinstance(trainer.model, PeftModel)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
@@ -471,15 +471,15 @@ def test_train_with_peft_model_and_gradient_checkpointing(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed"
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_with_pretokenized_data(self):
# Get the dataset
@@ -507,12 +507,12 @@ def tokenize_example(example):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_with_iterable_dataset(self):
# Get the dataset
@@ -535,12 +535,12 @@ def test_train_with_iterable_dataset(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_with_chat_template_kwargs(self):
# Get the dataset
@@ -569,12 +569,12 @@ def test_train_with_chat_template_kwargs(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_with_set_chat_template_from_model(self):
# Get the dataset
@@ -596,7 +596,7 @@ def test_train_with_set_chat_template_from_model(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
@@ -606,7 +606,7 @@ def test_train_with_set_chat_template_from_model(self):
# this parameter.
if n == "gpt_neox.final_layer_norm.bias":
continue
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_with_set_chat_template_from_path(self):
# Get the dataset
@@ -632,7 +632,7 @@ def test_train_with_set_chat_template_from_path(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
@@ -642,19 +642,17 @@ def test_train_with_set_chat_template_from_path(self):
# this parameter.
if n == "gpt_neox.final_layer_norm.bias":
continue
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
# Check that the template saved in the output directory is the same as the one used for training
template_path = pathlib.Path(self.tmp_dir) / "checkpoint-9" / "chat_template.jinja"
- self.assertTrue(template_path.exists(), f"Chat template not found at {template_path}")
+ assert template_path.exists(), f"Chat template not found at {template_path}"
with open(template_path) as f:
template_content = f.read()
with open(training_args.chat_template_path) as f:
original_template_content = f.read()
- self.assertEqual(
- template_content, original_template_content, "Chat template content does not match the original"
- )
+ assert template_content == original_template_content, "Chat template content does not match the original"
@unittest.skip("Skipping until we have a dataset with tool calls")
def test_train_toolcall_data(self):
@@ -676,12 +674,12 @@ def test_train_toolcall_data(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_with_eval(self):
# Get the dataset
@@ -700,7 +698,7 @@ def test_train_with_eval(self):
trainer.train()
# Check that the eval loss is not None
- self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
+ assert trainer.state.log_history[0]["eval_loss"] is not None
def test_train_with_multiple_eval_dataset(self):
# Get the dataset
@@ -718,8 +716,8 @@ def test_train_with_multiple_eval_dataset(self):
trainer.train()
# Check that the eval losses are not None
- self.assertIsNotNone(trainer.state.log_history[-3]["eval_data1_loss"])
- self.assertIsNotNone(trainer.state.log_history[-2]["eval_data2_loss"])
+ assert trainer.state.log_history[-3]["eval_data1_loss"] is not None
+ assert trainer.state.log_history[-2]["eval_data2_loss"] is not None
def test_train_with_gradient_checkpointing(self):
# Get the dataset
@@ -740,12 +738,12 @@ def test_train_with_gradient_checkpointing(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_tag_added(self):
# Get the dataset
@@ -758,7 +756,7 @@ def test_tag_added(self):
)
for tag in ["reward-trainer", "trl"]:
- self.assertIn(tag, trainer.model.model_tags)
+ assert tag in trainer.model.model_tags
@require_peft
def test_tag_added_peft(self):
@@ -773,7 +771,7 @@ def test_tag_added_peft(self):
)
for tag in ["reward-trainer", "trl"]:
- self.assertIn(tag, trainer.model.model_tags)
+ assert tag in trainer.model.model_tags
def test_train_with_margin(self):
# Get the dataset
@@ -800,12 +798,12 @@ def add_margin(example):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_with_center_rewards_coefficient(self):
# Get the dataset
@@ -826,9 +824,9 @@ def test_train_with_center_rewards_coefficient(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
diff --git a/tests/test_rewards.py b/tests/test_rewards.py
index 8b20a0ff7e9..aac6aabca0c 100644
--- a/tests/test_rewards.py
+++ b/tests/test_rewards.py
@@ -31,7 +31,7 @@ def test_valid_format(self):
completions = [[{"content": completion}] for completion in completions]
expected_rewards = [1.0, 1.0, 1.0, 1.0, 1.0] # All should be valid
rewards = think_format_reward(completions)
- self.assertEqual(rewards, expected_rewards)
+ assert rewards == expected_rewards
def test_invalid_format(self):
completions = [
@@ -48,7 +48,7 @@ def test_invalid_format(self):
completions = [[{"content": completion}] for completion in completions]
expected_rewards = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] # All should be invalid
rewards = think_format_reward(completions)
- self.assertEqual(rewards, expected_rewards)
+ assert rewards == expected_rewards
def test_mixed_format(self):
completions = [
@@ -60,7 +60,7 @@ def test_mixed_format(self):
completions = [[{"content": completion}] for completion in completions]
expected_rewards = [1.0, 1.0, 0.0, 0.0]
rewards = think_format_reward(completions)
- self.assertEqual(rewards, expected_rewards)
+ assert rewards == expected_rewards
class SoftOverlongPunishmentRewardTester(unittest.TestCase):
@@ -70,7 +70,7 @@ def test_soft_overlong_punishment_short_completion(self):
reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
completion_ids = [[1] * 50] # 50 <= 80
rewards = reward_fn(completion_ids=completion_ids)
- self.assertEqual(rewards, [0])
+ assert rewards == [0]
def test_soft_overlong_punishment_long_completion(self):
"""Test soft overlong punishment reward function with a longer than max completion."""
@@ -78,14 +78,14 @@ def test_soft_overlong_punishment_long_completion(self):
reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
completion_ids = [[1] * 110]
rewards = reward_fn(completion_ids)
- self.assertEqual(rewards, [-1])
+ assert rewards == [-1]
def test_soft_overlong_punishment_intermediate_completion(self):
"""Test soft overlong punishment reward function for intermediate length completion."""
reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
completion_ids = [[1] * 90] # 90 is between 80 and 100
rewards = reward_fn(completion_ids)
- self.assertAlmostEqual(rewards[0], -0.5, places=4)
+ assert round(abs(rewards[0] - (-0.5)), 4) == 0
if __name__ == "__main__":
diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py
index cde52de6047..2ddb51248ce 100644
--- a/tests/test_rloo_trainer.py
+++ b/tests/test_rloo_trainer.py
@@ -30,6 +30,7 @@
+import pytest
from trl import RLOOConfig, RLOOTrainer
from .testing_utils import TrlTestCase, require_vllm
if is_peft_available():
@@ -69,12 +70,12 @@ def test_training(self, config_name):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_with_eval(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
@@ -122,12 +123,12 @@ def test_training_multiple_iterations(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_peft
def test_training_peft(self):
@@ -155,15 +156,15 @@ def test_training_peft(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model params to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed."
elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
@require_peft
def test_training_peft_with_gradient_checkpointing(self):
@@ -197,22 +198,22 @@ def test_training_peft_with_gradient_checkpointing(self):
)
# Verify gradient checkpointing is enabled
- self.assertIsInstance(trainer.model, PeftModel)
+ assert isinstance(trainer.model, PeftModel)
# Store initial parameters to check which ones change
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that only LoRA parameters have changed, base model parameters remain unchanged
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if "lora" in n.lower(): # LoRA parameters should change
- self.assertFalse(torch.equal(param, new_param), f"LoRA parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"LoRA parameter {n} has not changed."
else: # Base model parameters should not change
- self.assertTrue(torch.equal(param, new_param), f"Base parameter {n} has changed.")
+ assert torch.equal(param, new_param), f"Base parameter {n} has changed."
def test_training_different_reward_model(self):
# Use a reward model different from the model: different chat template, tokenization, etc.
@@ -246,12 +247,12 @@ def test_training_different_reward_model(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_reward_func_standard(self):
# Test if trainer can handle reward function with standard format
@@ -280,12 +281,12 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_reward_func_conversational(self):
# Test if trainer can handle reward function with conversational format
@@ -315,12 +316,12 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_multiple_reward_funcs(self):
# Test that RLOOTrainer can be instantiated with multiple reward functions
@@ -353,12 +354,12 @@ def reward_func2(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_multiple_reward_funcs_with_None_output(self):
"""Test that a valid math reward function is processed correctly while the code reward function returns None."""
@@ -397,12 +398,12 @@ def non_applicable_reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_multiple_reward_funcs_with_weights(self):
"""Test that RLOOTrainer can handle multiple reward functions with weights."""
@@ -437,16 +438,16 @@ def reward_func2(completions, **kwargs):
trainer.train()
# Check that training logs contain both reward metrics
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
- self.assertIn("rewards/reward_func1/mean", trainer.state.log_history[-1])
- self.assertIn("rewards/reward_func1/std", trainer.state.log_history[-1])
- self.assertIn("rewards/reward_func2/mean", trainer.state.log_history[-1])
- self.assertIn("rewards/reward_func2/std", trainer.state.log_history[-1])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
+ assert "rewards/reward_func1/mean" in trainer.state.log_history[-1]
+ assert "rewards/reward_func1/std" in trainer.state.log_history[-1]
+ assert "rewards/reward_func2/mean" in trainer.state.log_history[-1]
+ assert "rewards/reward_func2/std" in trainer.state.log_history[-1]
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_multiple_mixed_reward_funcs(self):
# Test if the trainer can handle a mix of reward functions and reward models
@@ -475,12 +476,12 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_reward_func_additional_column(self):
# Test if trainer can handle reward function that rely on additional columns in the dataset
@@ -513,12 +514,12 @@ def reward_func(completions, some_values, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_with_sync_ref_model(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -544,12 +545,12 @@ def test_training_with_sync_ref_model(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_beta_zero(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -573,12 +574,12 @@ def test_training_beta_zero(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@unittest.skip("We should add a mock for the vLLM server.")
@require_peft
@@ -615,16 +616,16 @@ def test_training_vllm_and_peft(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model params to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed."
elif "base_layer" not in n and "original_module" not in n:
# We expect the peft params to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
@require_vllm
@unittest.skip("We should add a mock for the vLLM server.")
@@ -653,12 +654,12 @@ def test_training_vllm_guided_decoding(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_with_additional_generation_kwargs(self):
"""Test that training works with additional generation kwargs."""
@@ -688,12 +689,12 @@ def test_training_with_additional_generation_kwargs(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_vllm
@unittest.skip("We should add a mock for the vLLM server.")
@@ -726,12 +727,12 @@ def test_training_vllm_with_additional_generation_kwargs(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_with_normalized_advantages(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -756,12 +757,12 @@ def test_training_with_normalized_advantages(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_with_clipped_rewards(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -786,12 +787,12 @@ def test_training_with_clipped_rewards(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@patch("transformers.generation.utils.GenerationMixin.generate")
def test_training_with_mask_truncated_completions(self, mock_generate):
@@ -836,12 +837,12 @@ def fake_generate(input_ids, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_with_mask_truncated_completions_all_masked(self):
"""
@@ -874,12 +875,12 @@ def test_training_with_mask_truncated_completions_all_masked(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertTrue(torch.equal(param, new_param), f"Parameter {n} has changed.")
+ assert torch.equal(param, new_param), f"Parameter {n} has changed."
def test_warning_raised_all_rewards_none(self):
"""Test that a proper warning is raised when all rewards are None."""
@@ -908,7 +909,7 @@ def always_none_reward_func(completions, **kwargs):
trainer.train()
expected_warning = "All reward functions returned None for the following kwargs:"
- self.assertIn(expected_warning, cm.output[0])
+ assert expected_warning in cm.output[0]
def test_training_num_generations_larger_than_batch_size(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -933,12 +934,12 @@ def test_training_num_generations_larger_than_batch_size(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_multiple_dataloader_workers(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -963,12 +964,12 @@ def test_training_multiple_dataloader_workers(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_with_generation_kwargs(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -993,12 +994,12 @@ def test_training_with_generation_kwargs(self):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_training_with_reward_func_accessing_trainer_state(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -1070,7 +1071,7 @@ def test_prepare_input_called_with_correct_data(self):
with patch.object(RLOOTrainer, "training_step", wraps=trainer.training_step) as mock_prepare:
trainer.train()
# 3 epochs * 2 iterations * 2 generation batches to cover the dataset * 4 steps_per_generation
- self.assertEqual(mock_prepare.call_count, 48)
+ assert mock_prepare.call_count == 48
for i in range(0, 8): # Generation batch repeated 8 times (steps_per_generation*num_iterations)
assert mock_prepare.call_args_list[i].args[1] == expected_first_generation_batch
for i in range(8, 16):
@@ -1113,7 +1114,7 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
# Because of the way the tiny models are initialized, the gradient does not flow properly through the
@@ -1128,7 +1129,7 @@ def reward_func(completions, **kwargs):
if n.startswith(params_to_skip):
continue
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_vision
def test_training_vlm_beta_non_zero(self):
@@ -1158,7 +1159,7 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
# Because of the way the tiny models are initialized, the gradient does not flow properly through the
@@ -1168,7 +1169,7 @@ def reward_func(completions, **kwargs):
if n.startswith(params_to_skip):
continue
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_vision
@require_peft
@@ -1203,15 +1204,15 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model params to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed."
elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
@require_vision
def test_training_vlm_and_prompt_truncation(self):
@@ -1242,7 +1243,7 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check that the params have changed
# Because of the way the tiny models are initialized, the gradient does not flow properly through the
@@ -1252,7 +1253,7 @@ def reward_func(completions, **kwargs):
if n.startswith(params_to_skip):
continue
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_vision
@require_vllm
@@ -1292,11 +1293,11 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_vision
def test_training_vlm_multi_image(self):
@@ -1329,11 +1330,11 @@ def reward_func(completions, **kwargs):
trainer.train()
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
def test_mismatched_reward_processing_classes_length(self):
"""Test that mismatched length between reward_funcs and reward_processing_classes raises error."""
@@ -1352,7 +1353,7 @@ def test_mismatched_reward_processing_classes_length(self):
training_args = RLOOConfig(output_dir=self.tmp_dir, report_to="none")
- with self.assertRaises(ValueError) as context:
+ with pytest.raises(ValueError) as context:
RLOOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs=reward_models,
@@ -1361,7 +1362,7 @@ def test_mismatched_reward_processing_classes_length(self):
train_dataset=dataset,
)
- self.assertIn("must match", str(context.exception))
+ assert "must match" in str(context.exception)
def test_correct_reward_processing_classes_list(self):
"""Test that correct list of reward_processing_classes works properly."""
@@ -1392,7 +1393,7 @@ def test_correct_reward_processing_classes_list(self):
train_dataset=dataset,
)
- self.assertEqual(len(trainer.reward_processing_classes), len(reward_models))
+ assert len(trainer.reward_processing_classes) == len(reward_models)
def test_single_reward_model_with_single_processing_class(self):
"""Test that single reward model with single processing class works."""
@@ -1416,8 +1417,8 @@ def test_single_reward_model_with_single_processing_class(self):
train_dataset=dataset,
)
- self.assertEqual(len(trainer.reward_processing_classes), 1)
- self.assertEqual(trainer.reward_processing_classes[0], single_processing_class)
+ assert len(trainer.reward_processing_classes) == 1
+ assert trainer.reward_processing_classes[0] == single_processing_class
if __name__ == "__main__":
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 19d5a1b5d70..2a558d6093f 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -64,7 +64,7 @@ def test_basic_padding(self):
result = self.collator(examples)
- self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
+ assert set(result.keys()) == {"input_ids", "attention_mask", "labels"}
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))
@@ -79,7 +79,7 @@ def test_completion_mask(self):
result = self.collator(examples)
- self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
+ assert set(result.keys()) == {"input_ids", "attention_mask", "labels"}
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]]))
@@ -95,7 +95,7 @@ def test_completion_only_loss_disabled(self):
result = collator(examples)
# Labels should not be masked when completion_only_loss=False
- self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
+ assert set(result.keys()) == {"input_ids", "attention_mask", "labels"}
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))
@@ -107,7 +107,7 @@ def test_padding_free_mode(self):
result = collator(examples)
- self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
+ assert set(result.keys()) == {"input_ids", "position_ids", "labels"}
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5]]))
@@ -122,7 +122,7 @@ def test_padding_free_with_completion_mask(self):
result = collator(examples)
- self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
+ assert set(result.keys()) == {"input_ids", "position_ids", "labels"}
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
torch.testing.assert_close(result["labels"], torch.tensor([[-100, -100, 3, -100, 5]]))
@@ -139,7 +139,7 @@ def test_packing(self):
result = collator(examples)
- self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
+ assert set(result.keys()) == {"input_ids", "position_ids", "labels"}
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, 6, -100, 8, 9, 10, -100]]))
@@ -151,7 +151,7 @@ def test_pad_to_multiple_of(self):
result = collator(examples)
- self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
+ assert set(result.keys()) == {"input_ids", "attention_mask", "labels"}
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]]))
@@ -163,7 +163,7 @@ def test_pad_to_multiple_of_and_padding_free(self):
result = collator(examples)
- self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
+ assert set(result.keys()) == {"input_ids", "position_ids", "labels"}
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 0, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, -100, -100, -100]]))
@@ -175,7 +175,7 @@ def test_custom_position_ids_but_no_padding_free(self):
result = self.collator(examples)
- self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
+ assert set(result.keys()) == {"input_ids", "attention_mask", "labels"}
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))
@@ -187,7 +187,7 @@ def test_single_example(self):
result = self.collator(examples)
- self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
+ assert set(result.keys()) == {"input_ids", "attention_mask", "labels"}
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4]]))
@@ -199,7 +199,7 @@ def test_different_pad_token_id(self):
result = collator(examples)
- self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
+ assert set(result.keys()) == {"input_ids", "attention_mask", "labels"}
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 999]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))
@@ -221,22 +221,22 @@ def test_assistant_masks(self):
def test_single_example_single_doc(self):
batch_seq_lengths = [[5]]
result = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths(batch_seq_lengths)
- self.assertEqual(len(result), 1)
- self.assertTrue(torch.equal(result[0], torch.arange(5)))
+ assert len(result) == 1
+ assert torch.equal(result[0], torch.arange(5))
def test_single_example_multiple_docs(self):
batch_seq_lengths = [[3, 2]]
result = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths(batch_seq_lengths)
- self.assertEqual(len(result), 1)
+ assert len(result) == 1
# First sequence: 0, 1, 2; second sequence: 0, 1
- self.assertTrue(torch.equal(result[0], torch.tensor([0, 1, 2, 0, 1])))
+ assert torch.equal(result[0], torch.tensor([0, 1, 2, 0, 1]))
def test_multiple_examples(self):
batch_seq_lengths = [[2, 2], [3]]
result = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths(batch_seq_lengths)
- self.assertEqual(len(result), 2)
- self.assertTrue(torch.equal(result[0], torch.tensor([0, 1, 0, 1])))
- self.assertTrue(torch.equal(result[1], torch.arange(3)))
+ assert len(result) == 2
+ assert torch.equal(result[0], torch.tensor([0, 1, 0, 1]))
+ assert torch.equal(result[1], torch.arange(3))
class SFTTrainerTester(TrlTestCase):
@@ -262,12 +262,12 @@ def test_train(self, model_id):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
# Special case for harmony
def test_train_gpt_oss(self):
@@ -287,12 +287,12 @@ def test_train_gpt_oss(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_model(self):
# Instantiate the model
@@ -312,12 +312,12 @@ def test_train_model(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_dft_loss(self):
# Get the dataset
@@ -348,12 +348,12 @@ def test_train_dft_loss(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_moe_model_with_aux_loss(self):
# Get the dataset
@@ -375,13 +375,13 @@ def test_train_moe_model_with_aux_loss(self):
trainer.train()
# Check that the training loss and aux loss are not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
- self.assertIsNotNone(trainer.state.log_history[-1]["aux_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
+ assert trainer.state.log_history[-1]["aux_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_with_formatting_func(self):
# Dummy formatting function
@@ -408,12 +408,12 @@ def formatting_prompts_func(example):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_model_dtype(self):
# Get the dataset
@@ -437,7 +437,7 @@ def test_train_model_dtype(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
@@ -447,8 +447,8 @@ def test_train_model_dtype(self):
continue
new_param = trainer.model.get_parameter(n)
# Check the torch dtype
- self.assertEqual(new_param.dtype, torch.float16)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert new_param.dtype == torch.float16
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
@require_peft
def test_train_dense_with_peft_config(self):
@@ -477,15 +477,15 @@ def test_train_dense_with_peft_config(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed"
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
@require_peft
def test_train_moe_with_peft_config(self):
@@ -514,15 +514,15 @@ def test_train_moe_with_peft_config(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed"
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
@require_peft
def test_train_peft_model(self):
@@ -551,15 +551,15 @@ def test_train_peft_model(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed"
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
@require_peft
def test_train_dense_with_peft_config_and_gradient_checkpointing(self):
@@ -588,15 +588,15 @@ def test_train_dense_with_peft_config_and_gradient_checkpointing(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed"
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
@require_peft
def test_train_moe_with_peft_config_and_gradient_checkpointing(self):
@@ -625,15 +625,15 @@ def test_train_moe_with_peft_config_and_gradient_checkpointing(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed"
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
@require_peft
def test_train_with_peft_model_and_gradient_checkpointing(self):
@@ -652,7 +652,7 @@ def test_train_with_peft_model_and_gradient_checkpointing(self):
trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset)
# Verify model is a PeftModel
- self.assertIsInstance(trainer.model, PeftModel)
+ assert isinstance(trainer.model, PeftModel)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
@@ -661,15 +661,15 @@ def test_train_with_peft_model_and_gradient_checkpointing(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed"
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
@require_liger_kernel
def test_train_with_liger(self):
@@ -689,12 +689,12 @@ def test_train_with_liger(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_with_non_chatml_conversational_data(self):
# Get the dataset
@@ -719,12 +719,12 @@ def rename_fields(example: list[dict]):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_with_pretokenized_data(self):
# Get the dataset
@@ -749,12 +749,12 @@ def tokenize_example(example):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_with_iterable_dataset(self):
# Get the dataset
@@ -773,12 +773,12 @@ def test_train_with_iterable_dataset(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
@require_flash_attn
def test_train_padding_free(self):
@@ -804,12 +804,12 @@ def test_train_padding_free(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
@parameterized.expand([("bfd",), ("wrapped",)])
@ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning)
@@ -833,12 +833,12 @@ def test_train_packing(self, packing_strategy):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
@ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning)
@ignore_warnings(message="Padding-free training is enabled, but the attention.*", category=UserWarning)
@@ -863,16 +863,16 @@ def test_eval_packing(self):
# Check the number of sequences in train and eval datasets
num_train_seqs = sum(len(x) for x in trainer.train_dataset["seq_lengths"])
num_eval_seqs = sum(len(x) for x in trainer.eval_dataset["seq_lengths"])
- self.assertEqual(num_train_seqs, 17) # we should still have 17 seqs
- self.assertEqual(num_eval_seqs, 2) # we should still have 2 seqs
+ assert num_train_seqs == 17 # we should still have 17 seqs
+ assert num_eval_seqs == 2 # we should still have 2 seqs
# Check that all sequences are shorter than the max length
- self.assertTrue(all(sum(x) <= 64 for x in trainer.train_dataset["seq_lengths"]))
- self.assertTrue(all(sum(x) <= 64 for x in trainer.eval_dataset["seq_lengths"]))
+ assert all(sum(x) <= 64 for x in trainer.train_dataset["seq_lengths"])
+ assert all(sum(x) <= 64 for x in trainer.eval_dataset["seq_lengths"])
# Check the number of sequences in train and eval datasets
- self.assertEqual(len(trainer.train_dataset["input_ids"]), 3) # w/ this dataset, we end up with 46 seqs
- self.assertEqual(len(trainer.eval_dataset["input_ids"]), 1) # w/ this dataset, we end up with 6 seqs
+ assert len(trainer.train_dataset["input_ids"]) == 3 # w/ this dataset, we end up with 46 seqs
+ assert len(trainer.eval_dataset["input_ids"]) == 1 # w/ this dataset, we end up with 6 seqs
@ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning)
@ignore_warnings(message="Padding-free training is enabled, but the attention.*", category=UserWarning)
@@ -897,17 +897,17 @@ def test_only_train_packing(self):
# Check the number of sequences in train dataset
num_train_seqs = sum(len(x) for x in trainer.train_dataset["seq_lengths"])
- self.assertEqual(num_train_seqs, 17) # we should still have 17 seqs
+ assert num_train_seqs == 17 # we should still have 17 seqs
# We expect eval dataset not having "seq_lengths" as eval_packing is False
- self.assertNotIn("seq_lengths", trainer.eval_dataset)
+ assert "seq_lengths" not in trainer.eval_dataset
# Check that all sequences are shorter than the max length
- self.assertTrue(all(sum(x) <= 64 for x in trainer.train_dataset["seq_lengths"]))
+ assert all(sum(x) <= 64 for x in trainer.train_dataset["seq_lengths"])
# Check the number of sequences in train and eval datasets
- self.assertEqual(len(trainer.train_dataset["input_ids"]), 3) # w/ this dataset, we end up with 46 seqs
- self.assertEqual(len(trainer.eval_dataset["input_ids"]), 2) # w/ this dataset, we end up with 6 seqs
+ assert len(trainer.train_dataset["input_ids"]) == 3 # w/ this dataset, we end up with 46 seqs
+ assert len(trainer.eval_dataset["input_ids"]) == 2 # w/ this dataset, we end up with 6 seqs
def test_train_with_chat_template_kwargs(self):
# Get the dataset
@@ -934,12 +934,12 @@ def test_train_with_chat_template_kwargs(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_assistant_only(self):
# Get the dataset
@@ -958,12 +958,12 @@ def test_train_assistant_only(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_completion_only(self):
# Get the dataset
@@ -982,12 +982,12 @@ def test_train_completion_only(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_completion_only_harmony(self):
# Get the dataset
@@ -1006,12 +1006,12 @@ def test_train_completion_only_harmony(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_assistant_only_and_completion_only(self):
# Get the dataset
@@ -1040,12 +1040,12 @@ def add_to_completion(example):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_assistant_only_iterable_dataset(self):
# Get the dataset
@@ -1066,12 +1066,12 @@ def test_train_assistant_only_iterable_dataset(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_with_set_chat_template_from_model(self):
# Get the dataset
@@ -1091,12 +1091,12 @@ def test_train_with_set_chat_template_from_model(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_with_set_chat_template_from_path(self):
# Get the dataset
@@ -1120,24 +1120,22 @@ def test_train_with_set_chat_template_from_path(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
# Check that the template saved in the output directory is the same as the one used for training
template_path = pathlib.Path(self.tmp_dir) / "checkpoint-9" / "chat_template.jinja"
- self.assertTrue(template_path.exists(), f"Chat template not found at {template_path}")
+ assert template_path.exists(), f"Chat template not found at {template_path}"
with open(template_path) as f:
template_content = f.read()
with open(training_args.chat_template_path) as f:
original_template_content = f.read()
- self.assertEqual(
- template_content, original_template_content, "Chat template content does not match the original"
- )
+ assert template_content == original_template_content, "Chat template content does not match the original"
def test_train_toolcall_data(self):
# Get the dataset
@@ -1156,12 +1154,12 @@ def test_train_toolcall_data(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_train_with_eval(self):
# Get the dataset
@@ -1180,7 +1178,7 @@ def test_train_with_eval(self):
trainer.train()
# Check that the eval loss is not None
- self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
+ assert trainer.state.log_history[0]["eval_loss"] is not None
def test_train_with_multiple_eval_dataset(self):
# Get the dataset
@@ -1198,8 +1196,8 @@ def test_train_with_multiple_eval_dataset(self):
trainer.train()
# Check that the eval losses are not None
- self.assertIsNotNone(trainer.state.log_history[-3]["eval_data1_loss"])
- self.assertIsNotNone(trainer.state.log_history[-2]["eval_data2_loss"])
+ assert trainer.state.log_history[-3]["eval_data1_loss"] is not None
+ assert trainer.state.log_history[-2]["eval_data2_loss"] is not None
def test_train_with_gradient_checkpointing(self):
# Get the dataset
@@ -1218,12 +1216,12 @@ def test_train_with_gradient_checkpointing(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
def test_tag_added(self):
# Get the dataset
@@ -1236,7 +1234,7 @@ def test_tag_added(self):
)
for tag in ["sft", "trl"]:
- self.assertIn(tag, trainer.model.model_tags)
+ assert tag in trainer.model.model_tags
@require_peft
def test_tag_added_peft(self):
@@ -1251,7 +1249,7 @@ def test_tag_added_peft(self):
)
for tag in ["sft", "trl"]:
- self.assertIn(tag, trainer.model.model_tags)
+ assert tag in trainer.model.model_tags
@parameterized.expand(
[
@@ -1285,7 +1283,7 @@ def test_train_vlm(self, model_id):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
@@ -1302,9 +1300,7 @@ def test_train_vlm(self, model_id):
):
# fmt: on
continue
- self.assertFalse(
- torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
- )
+ assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
@require_vision
def test_train_vlm_prompt_completion(self):
@@ -1330,12 +1326,12 @@ def test_train_vlm_prompt_completion(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")
+ assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
# Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing.
# To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs.
@@ -1363,7 +1359,7 @@ def test_train_vlm_gemma_3n(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
@@ -1371,7 +1367,7 @@ def test_train_vlm_gemma_3n(self):
if "model.vision_tower" in n:
# The vision tower is not updated, not sure why at this point.
continue
- self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")
+ assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
@require_vision
def test_train_vlm_text_only_data(self):
@@ -1393,15 +1389,15 @@ def test_train_vlm_text_only_data(self):
trainer.train()
# Check that the training loss is not None
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n.startswith("model.visual"):
- self.assertTrue(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is updated")
+ assert torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is updated"
else:
- self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")
+ assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
@require_peft
def test_prompt_tuning(self):
@@ -1422,16 +1418,16 @@ def test_prompt_tuning(self):
trainer.train()
# Check that training completed successfully
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
- self.assertIsNotNone(trainer.state.log_history[-1]["mean_token_accuracy"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
+ assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if "base_model" in n: # We expect the base model parameters to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed"
elif "prompt_encoder" in n: # We expect the peft parameters to be different
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
else:
raise ValueError(f"Unexpected parameter {n} in model: {trainer.model}")
@@ -1455,7 +1451,7 @@ def test_peft_model_with_quantization(self):
# Verify that this triggers the is_qlora condition
is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
- self.assertTrue(is_qlora, "Model should be detected as QLoRA (quantized)")
+ assert is_qlora, "Model should be detected as QLoRA (quantized)"
# Create LoRA configuration suitable for QLoRA
lora_config = LoraConfig(
@@ -1470,7 +1466,7 @@ def test_peft_model_with_quantization(self):
peft_model = get_peft_model(model, lora_config)
# Verify the quantization attributes are preserved on the PeftModel
- self.assertTrue(getattr(peft_model, "is_loaded_in_4bit", False), "PeftModel should preserve quantization flag")
+ assert getattr(peft_model, "is_loaded_in_4bit", False), "PeftModel should preserve quantization flag"
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
@@ -1489,9 +1485,9 @@ def test_peft_model_with_quantization(self):
base_params_before.append(name)
# Ensure we have the expected parameter distribution for QLoRA
- self.assertTrue(len(trainable_params_before) > 0, "PeftModel should have trainable parameters initially")
- self.assertTrue(len(lora_params_before) > 0, "PeftModel should have trainable LoRA parameters")
- self.assertEqual(len(base_params_before), 0, "Base model parameters should already be frozen in PeftModel")
+ assert len(trainable_params_before) > 0, "PeftModel should have trainable parameters initially"
+ assert len(lora_params_before) > 0, "PeftModel should have trainable LoRA parameters"
+ assert len(base_params_before) == 0, "Base model parameters should already be frozen in PeftModel"
# Initialize the trainer with the already configured PeftModel
training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", max_steps=1)
@@ -1508,32 +1504,24 @@ def test_peft_model_with_quantization(self):
lora_params_after.append(name)
# LoRA parameters should remain trainable
- self.assertTrue(
- len(trainable_params_after) > 0,
- f"PeftModel should still have trainable parameters after SFTTrainer initialization. "
- f"Found {len(trainable_params_after)} trainable params. "
- f"This test fails without the fix for issue #3926.",
- )
+ assert len(trainable_params_after) > 0, (
+ "PeftModel should still have trainable parameters after SFTTrainer initialization. "
+ f"Found {len(trainable_params_after)} trainable params. "
+ "This test fails without the fix for issue #3926.")
- self.assertTrue(
- len(lora_params_after) > 0,
- f"LoRA adapter parameters should remain trainable. "
- f"Found {len(lora_params_after)} trainable LoRA params out of {len(lora_params_before)} original.",
- )
+ assert len(lora_params_after) > 0, (
+ "LoRA adapter parameters should remain trainable. "
+ f"Found {len(lora_params_after)} trainable LoRA params out of {len(lora_params_before)} original.")
# Ensure the parameter counts are preserved (no additional freezing occurred)
- self.assertEqual(
- len(trainable_params_before),
- len(trainable_params_after),
- "Number of trainable parameters should not change after SFTTrainer initialization",
- )
+ assert len(trainable_params_before) == len(trainable_params_after), (
+ "Number of trainable parameters should not change after SFTTrainer initialization"
+ )
# Verify that all original LoRA parameters are still trainable
- self.assertEqual(
- set(lora_params_before),
- set(lora_params_after),
- "All original LoRA parameters should remain trainable after SFTTrainer initialization",
- )
+ assert set(lora_params_before) == set(lora_params_after), (
+ "All original LoRA parameters should remain trainable after SFTTrainer initialization"
+ )
@require_peft
def test_prompt_tuning_peft_model(self):
@@ -1552,15 +1540,15 @@ def test_prompt_tuning_peft_model(self):
trainer.train()
# Check that training completed successfully
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
- self.assertIsNotNone(trainer.state.log_history[-1]["mean_token_accuracy"])
+ assert trainer.state.log_history[-1]["train_loss"] is not None
+ assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if "base_model" in n: # We expect the base model parameters to be the same
- self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+ assert torch.allclose(param, new_param), f"Parameter {n} has changed"
elif "prompt_encoder" in n: # We expect the peft parameters to be different
- self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+ assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
else:
raise ValueError(f"Unexpected parameter {n} in model: {trainer.model}")
diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py
index 61ab72130f2..35a8da57cd3 100644
--- a/tests/test_trainers_args.py
+++ b/tests/test_trainers_args.py
@@ -76,22 +76,22 @@ def test_bco(self):
train_dataset=dataset,
processing_class=tokenizer,
)
- self.assertEqual(trainer.args.max_length, 256)
- self.assertEqual(trainer.args.max_prompt_length, 64)
- self.assertEqual(trainer.args.max_completion_length, 64)
- self.assertEqual(trainer.args.beta, 0.5)
- self.assertEqual(trainer.args.label_pad_token_id, -99)
- self.assertEqual(trainer.args.padding_value, -99)
- self.assertEqual(trainer.args.truncation_mode, "keep_start")
+ assert trainer.args.max_length == 256
+ assert trainer.args.max_prompt_length == 64
+ assert trainer.args.max_completion_length == 64
+ assert trainer.args.beta == 0.5
+ assert trainer.args.label_pad_token_id == -99
+ assert trainer.args.padding_value == -99
+ assert trainer.args.truncation_mode == "keep_start"
# self.assertEqual(trainer.args.generate_during_eval, True)
- self.assertEqual(trainer.args.is_encoder_decoder, True)
- self.assertEqual(trainer.args.precompute_ref_log_probs, True)
- self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True})
- self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True})
- self.assertEqual(trainer.args.dataset_num_proc, 4)
- self.assertEqual(trainer.args.prompt_sample_size, 512)
- self.assertEqual(trainer.args.min_density_ratio, 0.2)
- self.assertEqual(trainer.args.max_density_ratio, 20.0)
+ assert trainer.args.is_encoder_decoder is True
+ assert trainer.args.precompute_ref_log_probs is True
+ assert trainer.args.model_init_kwargs == {"trust_remote_code": True}
+ assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True}
+ assert trainer.args.dataset_num_proc == 4
+ assert trainer.args.prompt_sample_size == 512
+ assert trainer.args.min_density_ratio == 0.2
+ assert trainer.args.max_density_ratio == 20.0
def test_cpo(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
@@ -117,22 +117,22 @@ def test_cpo(self):
dataset_num_proc=4,
)
trainer = CPOTrainer(model=model_id, args=training_args, train_dataset=dataset, processing_class=tokenizer)
- self.assertEqual(trainer.args.max_length, 256)
- self.assertEqual(trainer.args.max_prompt_length, 64)
- self.assertEqual(trainer.args.max_completion_length, 64)
- self.assertEqual(trainer.args.beta, 0.5)
- self.assertEqual(trainer.args.label_smoothing, 0.5)
- self.assertEqual(trainer.args.loss_type, "hinge")
- self.assertEqual(trainer.args.disable_dropout, False)
- self.assertEqual(trainer.args.cpo_alpha, 0.5)
- self.assertEqual(trainer.args.simpo_gamma, 0.2)
- self.assertEqual(trainer.args.label_pad_token_id, -99)
- self.assertEqual(trainer.args.padding_value, -99)
- self.assertEqual(trainer.args.truncation_mode, "keep_start")
+ assert trainer.args.max_length == 256
+ assert trainer.args.max_prompt_length == 64
+ assert trainer.args.max_completion_length == 64
+ assert trainer.args.beta == 0.5
+ assert trainer.args.label_smoothing == 0.5
+ assert trainer.args.loss_type == "hinge"
+ assert trainer.args.disable_dropout is False
+ assert trainer.args.cpo_alpha == 0.5
+ assert trainer.args.simpo_gamma == 0.2
+ assert trainer.args.label_pad_token_id == -99
+ assert trainer.args.padding_value == -99
+ assert trainer.args.truncation_mode == "keep_start"
# self.assertEqual(trainer.args.generate_during_eval, True)
- self.assertEqual(trainer.args.is_encoder_decoder, True)
- self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True})
- self.assertEqual(trainer.args.dataset_num_proc, 4)
+ assert trainer.args.is_encoder_decoder is True
+ assert trainer.args.model_init_kwargs == {"trust_remote_code": True}
+ assert trainer.args.dataset_num_proc == 4
def test_dpo(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
@@ -174,32 +174,32 @@ def test_dpo(self):
train_dataset=dataset,
processing_class=tokenizer,
)
- self.assertEqual(trainer.args.beta, 0.5)
- self.assertEqual(trainer.args.label_smoothing, 0.5)
- self.assertEqual(trainer.args.loss_type, "hinge")
- self.assertEqual(trainer.args.label_pad_token_id, -99)
- self.assertEqual(trainer.args.pad_token, ".")
- self.assertEqual(trainer.args.truncation_mode, "keep_start")
- self.assertEqual(trainer.args.max_length, 256)
- self.assertEqual(trainer.args.max_prompt_length, 64)
- self.assertEqual(trainer.args.max_completion_length, 64)
- self.assertEqual(trainer.args.disable_dropout, False)
+ assert trainer.args.beta == 0.5
+ assert trainer.args.label_smoothing == 0.5
+ assert trainer.args.loss_type == "hinge"
+ assert trainer.args.label_pad_token_id == -99
+ assert trainer.args.pad_token == "."
+ assert trainer.args.truncation_mode == "keep_start"
+ assert trainer.args.max_length == 256
+ assert trainer.args.max_prompt_length == 64
+ assert trainer.args.max_completion_length == 64
+ assert trainer.args.disable_dropout is False
# self.assertEqual(trainer.args.generate_during_eval, True)
- self.assertEqual(trainer.args.precompute_ref_log_probs, True)
- self.assertEqual(trainer.args.dataset_num_proc, 4)
- self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True})
- self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True})
- self.assertEqual(trainer.args.model_adapter_name, "dummy_adapter")
- self.assertEqual(trainer.args.ref_adapter_name, "dummy_adapter")
- self.assertEqual(trainer.args.reference_free, True)
- self.assertEqual(trainer.args.force_use_ref_model, True)
- self.assertEqual(trainer.args.f_divergence_type, FDivergenceType.JS_DIVERGENCE)
- self.assertEqual(trainer.args.f_alpha_divergence_coef, 0.5)
+ assert trainer.args.precompute_ref_log_probs is True
+ assert trainer.args.dataset_num_proc == 4
+ assert trainer.args.model_init_kwargs == {"trust_remote_code": True}
+ assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True}
+ assert trainer.args.model_adapter_name == "dummy_adapter"
+ assert trainer.args.ref_adapter_name == "dummy_adapter"
+ assert trainer.args.reference_free is True
+ assert trainer.args.force_use_ref_model is True
+ assert trainer.args.f_divergence_type == FDivergenceType.JS_DIVERGENCE
+ assert trainer.args.f_alpha_divergence_coef == 0.5
# self.assertEqual(trainer.args.sync_ref_model, True)
- self.assertEqual(trainer.args.ref_model_mixup_alpha, 0.5)
- self.assertEqual(trainer.args.ref_model_sync_steps, 32)
- self.assertEqual(trainer.args.rpo_alpha, 0.5)
- self.assertEqual(trainer.args.discopop_tau, 0.1)
+ assert trainer.args.ref_model_mixup_alpha == 0.5
+ assert trainer.args.ref_model_sync_steps == 32
+ assert trainer.args.rpo_alpha == 0.5
+ assert trainer.args.discopop_tau == 0.1
def test_kto(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
@@ -230,21 +230,21 @@ def test_kto(self):
train_dataset=dataset,
processing_class=tokenizer,
)
- self.assertEqual(trainer.args.max_length, 256)
- self.assertEqual(trainer.args.max_prompt_length, 64)
- self.assertEqual(trainer.args.max_completion_length, 64)
- self.assertEqual(trainer.args.beta, 0.5)
- self.assertEqual(trainer.args.desirable_weight, 0.5)
- self.assertEqual(trainer.args.undesirable_weight, 0.5)
- self.assertEqual(trainer.args.label_pad_token_id, -99)
- self.assertEqual(trainer.args.padding_value, -99)
- self.assertEqual(trainer.args.truncation_mode, "keep_start")
+ assert trainer.args.max_length == 256
+ assert trainer.args.max_prompt_length == 64
+ assert trainer.args.max_completion_length == 64
+ assert trainer.args.beta == 0.5
+ assert trainer.args.desirable_weight == 0.5
+ assert trainer.args.undesirable_weight == 0.5
+ assert trainer.args.label_pad_token_id == -99
+ assert trainer.args.padding_value == -99
+ assert trainer.args.truncation_mode == "keep_start"
# self.assertEqual(trainer.args.generate_during_eval, True)
- self.assertEqual(trainer.args.is_encoder_decoder, True)
- self.assertEqual(trainer.args.precompute_ref_log_probs, True)
- self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True})
- self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True})
- self.assertEqual(trainer.args.dataset_num_proc, 4)
+ assert trainer.args.is_encoder_decoder is True
+ assert trainer.args.precompute_ref_log_probs is True
+ assert trainer.args.model_init_kwargs == {"trust_remote_code": True}
+ assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True}
+ assert trainer.args.dataset_num_proc == 4
@parameterized.expand([(False,), (True,)])
def test_nash_md(self, mixtures_coef_list):
@@ -266,7 +266,7 @@ def test_nash_md(self, mixtures_coef_list):
reward_funcs=reward_model,
train_dataset=dataset,
)
- self.assertEqual(trainer.args.mixture_coef, 0.5 if not mixtures_coef_list else [0.5, 0.6])
+ assert trainer.args.mixture_coef == (0.5 if not mixtures_coef_list else [0.5, 0.6])
@parameterized.expand([(False,), (True,)])
def test_online_dpo(self, beta_list):
@@ -293,11 +293,11 @@ def test_online_dpo(self, beta_list):
processing_class=tokenizer,
reward_processing_classes=tokenizer,
)
- self.assertEqual(trainer.args.max_new_tokens, 42)
- self.assertEqual(trainer.args.temperature, 0.5)
- self.assertEqual(trainer.args.missing_eos_penalty, 0.33)
- self.assertEqual(trainer.args.beta, 0.6 if not beta_list else [0.6, 0.7])
- self.assertEqual(trainer.args.loss_type, "hinge")
+ assert trainer.args.max_new_tokens == 42
+ assert trainer.args.temperature == 0.5
+ assert trainer.args.missing_eos_penalty == 0.33
+ assert trainer.args.beta == (0.6 if not beta_list else [0.6, 0.7])
+ assert trainer.args.loss_type == "hinge"
def test_orpo(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
@@ -319,12 +319,12 @@ def test_orpo(self):
dataset_num_proc=4,
)
trainer = ORPOTrainer(model=model_id, args=training_args, train_dataset=dataset, processing_class=tokenizer)
- self.assertEqual(trainer.args.max_length, 256)
- self.assertEqual(trainer.args.max_prompt_length, 64)
- self.assertEqual(trainer.args.max_completion_length, 64)
- self.assertEqual(trainer.args.beta, 0.5)
- self.assertEqual(trainer.args.disable_dropout, False)
- self.assertEqual(trainer.args.label_pad_token_id, -99)
+ assert trainer.args.max_length == 256
+ assert trainer.args.max_prompt_length == 64
+ assert trainer.args.max_completion_length == 64
+ assert trainer.args.beta == 0.5
+ assert trainer.args.disable_dropout is False
+ assert trainer.args.label_pad_token_id == -99
def test_reward(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
@@ -343,9 +343,9 @@ def test_reward(self):
train_dataset=dataset,
processing_class=tokenizer,
)
- self.assertEqual(trainer.args.max_length, 256)
- self.assertEqual(trainer.args.dataset_num_proc, 4)
- self.assertEqual(trainer.args.center_rewards_coefficient, 0.1)
+ assert trainer.args.max_length == 256
+ assert trainer.args.dataset_num_proc == 4
+ assert trainer.args.center_rewards_coefficient == 0.1
def test_sft(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
@@ -362,15 +362,15 @@ def test_sft(self):
eval_packing=True,
)
trainer = SFTTrainer(model_id, args=training_args, train_dataset=dataset)
- self.assertEqual(trainer.args.dataset_text_field, "dummy_text_field")
- self.assertEqual(trainer.args.packing, True)
- self.assertEqual(trainer.args.max_length, 256)
- self.assertEqual(trainer.args.dataset_num_proc, 4)
- self.assertEqual(trainer.args.neftune_noise_alpha, 0.1)
- self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True})
- self.assertIn("append_concat_token", trainer.args.dataset_kwargs)
- self.assertEqual(trainer.args.dataset_kwargs["append_concat_token"], True)
- self.assertEqual(trainer.args.eval_packing, True)
+ assert trainer.args.dataset_text_field == "dummy_text_field"
+ assert trainer.args.packing is True
+ assert trainer.args.max_length == 256
+ assert trainer.args.dataset_num_proc == 4
+ assert trainer.args.neftune_noise_alpha == 0.1
+ assert trainer.args.model_init_kwargs == {"trust_remote_code": True}
+ assert "append_concat_token" in trainer.args.dataset_kwargs
+ assert trainer.args.dataset_kwargs["append_concat_token"] is True
+ assert trainer.args.eval_packing is True
@parameterized.expand([(False,), (True,)])
def test_xpo(self, alpha_list):
@@ -392,4 +392,4 @@ def test_xpo(self, alpha_list):
reward_funcs=reward_model,
train_dataset=dataset,
)
- self.assertEqual(trainer.args.alpha, 0.5 if not alpha_list else [0.5, 0.6])
+ assert trainer.args.alpha == (0.5 if not alpha_list else [0.5, 0.6])
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 4758109e08b..2bfa9bff0c2 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -47,6 +47,7 @@
)
from .testing_utils import TrlTestCase, require_rich
+import pytest
if is_peft_available():
@@ -59,14 +60,14 @@ def test_pad_1_dim_left(self):
y = torch.tensor([4, 5])
output = pad((x, y), padding_value=0, padding_side="left")
expected = torch.tensor([[1, 2, 3], [0, 4, 5]])
- self.assertTrue(torch.equal(output, expected))
+ assert torch.equal(output, expected)
def test_pad_1_dim_right(self):
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5])
output = pad((x, y), padding_value=0, padding_side="right")
expected = torch.tensor([[1, 2, 3], [4, 5, 0]])
- self.assertTrue(torch.equal(output, expected))
+ assert torch.equal(output, expected)
def test_pad_2_dim_left(self):
x = torch.tensor([[1, 2], [3, 4]])
@@ -78,7 +79,7 @@ def test_pad_2_dim_left(self):
[[0, 0], [5, 6]],
]
)
- self.assertTrue(torch.equal(output, expected))
+ assert torch.equal(output, expected)
def test_pad_2_dim_right(self):
x = torch.tensor([[1, 2], [3, 4]])
@@ -90,7 +91,7 @@ def test_pad_2_dim_right(self):
[[5, 6], [0, 0]],
]
)
- self.assertTrue(torch.equal(output, expected))
+ assert torch.equal(output, expected)
def test_pad_2_dim_right_multidim(self):
x = torch.tensor([[1, 2], [3, 4]])
@@ -102,7 +103,7 @@ def test_pad_2_dim_right_multidim(self):
[[5, 0], [0, 0]],
]
)
- self.assertTrue(torch.equal(output, expected))
+ assert torch.equal(output, expected)
def test_pad_to_multiple_of_1(self):
x = torch.tensor([1, 2, 3])
@@ -110,7 +111,7 @@ def test_pad_to_multiple_of_1(self):
# Max length is 3, pad to multiple of 4
output = pad((x, y), padding_value=0, padding_side="right", pad_to_multiple_of=4)
expected = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])
- self.assertTrue(torch.equal(output, expected))
+ assert torch.equal(output, expected)
def test_pad_to_multiple_of_2(self):
x = torch.tensor([1, 2, 3, 4, 5])
@@ -118,7 +119,7 @@ def test_pad_to_multiple_of_2(self):
# Max length is 5, pad to the next multiple of 4 (8)
output = pad((x, y), padding_value=0, padding_side="right", pad_to_multiple_of=4)
expected = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0], [6, 7, 8, 0, 0, 0, 0, 0]])
- self.assertTrue(torch.equal(output, expected))
+ assert torch.equal(output, expected)
def test_pad_to_multiple_of_side_left(self):
x = torch.tensor([1, 2, 3, 4, 5])
@@ -126,7 +127,7 @@ def test_pad_to_multiple_of_side_left(self):
# Max length is 5, pad to the next multiple of 4 (8)
output = pad((x, y), padding_value=0, padding_side="left", pad_to_multiple_of=4)
expected = torch.tensor([[0, 0, 0, 1, 2, 3, 4, 5], [0, 0, 0, 0, 0, 6, 7, 8]])
- self.assertTrue(torch.equal(output, expected))
+ assert torch.equal(output, expected)
def test_pad_to_multiple_of_no_extra_padding(self):
x = torch.tensor([1, 2, 3, 4])
@@ -134,7 +135,7 @@ def test_pad_to_multiple_of_no_extra_padding(self):
# Already multiple of 4
output = pad((x, y), padding_value=0, padding_side="left", pad_to_multiple_of=4)
expected = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
- self.assertTrue(torch.equal(output, expected))
+ assert torch.equal(output, expected)
@require_peft
@@ -143,7 +144,7 @@ def test_create_peft_config_use_peft_false(self):
"""Test that when use_peft is False, the function returns None."""
model_args = ModelConfig(use_peft=False)
peft_config = get_peft_config(model_args)
- self.assertIsNone(peft_config)
+ assert peft_config is None
def test_create_peft_config_use_peft_true(self):
"""Test that when use_peft is True, the function returns a LoraConfig object."""
@@ -159,7 +160,7 @@ def test_create_peft_config_use_peft_true(self):
}
model_args = ModelConfig(use_peft=True, **peft_kwargs)
peft_config = get_peft_config(model_args)
- self.assertTrue(isinstance(peft_config, LoraConfig))
+ assert isinstance(peft_config, LoraConfig)
for arg, value in peft_kwargs.items():
# Test that lists of modules are converted to sets
if arg == "lora_target_modules":
@@ -168,7 +169,7 @@ def test_create_peft_config_use_peft_true(self):
if arg in ["lora_r", "lora_task_type", "lora_target_modules", "lora_modules_to_save"]:
arg = arg[len("lora_") :] if arg.startswith("lora_") else arg
- self.assertEqual(getattr(peft_config, arg), value)
+ assert getattr(peft_config, arg) == value
class TestDecodeAndStripPadding(TrlTestCase):
@@ -179,12 +180,12 @@ def setUp(self):
def test_example_with_padding(self):
inputs = self.tokenizer(["Hello world", "Hello"], padding=True, return_tensors="pt")
decoded = decode_and_strip_padding(inputs["input_ids"], self.tokenizer)
- self.assertEqual(decoded, ["Hello world", "Hello"])
+ assert decoded == ["Hello world", "Hello"]
def test_example_without_padding(self):
inputs = self.tokenizer(["Hello", "Hello"], padding=False, return_tensors="pt")
decoded = decode_and_strip_padding(inputs["input_ids"], self.tokenizer)
- self.assertEqual(decoded, ["Hello", "Hello"])
+ assert decoded == ["Hello", "Hello"]
class TestGenerateModelCard(TrlTestCase):
@@ -203,15 +204,15 @@ def test_full(self):
paper_id="1234.56789",
)
card_text = str(model_card)
- self.assertIn("[username/my_base_model](https://huggingface.co/username/my_base_model)", card_text)
- self.assertIn("my_model", card_text)
- self.assertIn('pipeline("text-generation", model="username/my_hub_model", device="cuda")', card_text)
- self.assertIn("datasets: username/my_dataset", card_text)
- self.assertIn("](https://wandb.ai/username/project_id/runs/abcd1234)", card_text)
- self.assertIn("](https://www.comet.com/username/project_id/experiment_id", card_text)
- self.assertIn("My Trainer", card_text)
- self.assertIn("```bibtex\n@article{my_trainer, ...}\n```", card_text)
- self.assertIn("[My Paper](https://huggingface.co/papers/1234.56789)", card_text)
+ assert "[username/my_base_model](https://huggingface.co/username/my_base_model)" in card_text
+ assert "my_model" in card_text
+ assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text
+ assert "datasets: username/my_dataset" in card_text
+ assert "](https://wandb.ai/username/project_id/runs/abcd1234)" in card_text
+ assert "](https://www.comet.com/username/project_id/experiment_id" in card_text
+ assert "My Trainer" in card_text
+ assert "```bibtex\n@article{my_trainer, ...}\n```" in card_text
+ assert "[My Paper](https://huggingface.co/papers/1234.56789)" in card_text
def test_val_none(self):
model_card = generate_model_card(
@@ -228,9 +229,9 @@ def test_val_none(self):
paper_id=None,
)
card_text = str(model_card)
- self.assertIn("my_model", card_text)
- self.assertIn('pipeline("text-generation", model="username/my_hub_model", device="cuda")', card_text)
- self.assertIn("My Trainer", card_text)
+ assert "my_model" in card_text
+ assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text
+ assert "My Trainer" in card_text
class TestDataCollatorForChatML(TrlTestCase):
@@ -265,11 +266,11 @@ def test_data_collator_for_chatml(self):
data = self.collator(self.examples)
# Verify basic shapes and types
- self.assertIn("input_ids", data)
- self.assertIn("attention_mask", data)
- self.assertIn("labels", data)
- self.assertIn("prompts", data)
- self.assertIn("prompt_attention_mask", data)
+ assert "input_ids" in data
+ assert "attention_mask" in data
+ assert "labels" in data
+ assert "prompts" in data
+ assert "prompt_attention_mask" in data
# Decode input_ids and labels for verification
input_ids = data["input_ids"][0].tolist()
@@ -278,23 +279,21 @@ def test_data_collator_for_chatml(self):
# Get the last assistant's response for comparison
last_message = self.examples[0][self.messages_key][-1]
- self.assertEqual(last_message["role"], "assistant", "Last message should be from assistant")
+ assert last_message["role"] == "assistant", "Last message should be from assistant"
last_assistant_response = last_message["content"]
# Verify that input_ids contain both prompt and response
decoded_input = self.tokenizer.decode(input_ids)
- self.assertIn(last_assistant_response, decoded_input, "Input should contain assistant's response")
+ assert last_assistant_response in decoded_input, "Input should contain assistant's response"
# Verify that prompts only contain the conversation up to the last response
decoded_prompt = self.tokenizer.decode(prompt_only)
- self.assertNotIn(last_assistant_response, decoded_prompt, "Prompt should not contain assistant's response")
+ assert last_assistant_response not in decoded_prompt, "Prompt should not contain assistant's response"
# Verify labels are -100 for non-assistant parts
prompt_length = len(prompt_only)
- self.assertTrue(
- all(label == self.ignore_index for label in labels[:prompt_length]),
- "Labels should be ignore_index for prompt tokens",
- )
+ assert all(label == self.ignore_index for label in labels[:prompt_length]), \
+ "Labels should be ignore_index for prompt tokens"
# Verify labels match assistant response after prompt
# Add a filter to remove any trailing tokens after the first <|im_end|>
@@ -310,24 +309,18 @@ def test_data_collator_for_chatml(self):
response_labels.append(label)
if label == self.tokenizer.convert_tokens_to_ids("<|im_end|>"):
break
- self.assertEqual(
- response_labels,
- last_assistant_response_tokens,
- "Labels should match assistant response tokens",
- )
+ assert response_labels == \
+ last_assistant_response_tokens, \
+ "Labels should match assistant response tokens"
# Verify there isn't a generation prompt at the end
generation_prompt = "<|im_start|>assistant"
- self.assertFalse(
- decoded_input.strip().endswith(generation_prompt),
- f"Input should not end with generation prompt '{generation_prompt}'",
- )
+ assert not decoded_input.strip().endswith(generation_prompt), \
+ f"Input should not end with generation prompt '{generation_prompt}'"
- self.assertEqual(
- response_labels,
- last_assistant_response_tokens,
- "Labels should match assistant response tokens",
- )
+ assert response_labels == \
+ last_assistant_response_tokens, \
+ "Labels should match assistant response tokens"
class TestBatchGeneration(TrlTestCase):
@@ -367,9 +360,9 @@ def test_mini_batch_generation(self):
max_length_query = query_responses.shape[1]
max_length_logits = max_length_query - context_length
- self.assertGreater(max_length_query, context_length)
- self.assertEqual(query_responses.shape, (bs, max_length_query))
- self.assertEqual(logits.shape, (bs, max_length_logits, self.model.config.vocab_size))
+ assert max_length_query > context_length
+ assert query_responses.shape == (bs, max_length_query)
+ assert logits.shape == (bs, max_length_logits, self.model.config.vocab_size)
def test_single_batch_generation(self):
batch = [
@@ -386,9 +379,9 @@ def test_single_batch_generation(self):
max_length_query = query_responses.shape[1]
max_length_logits = max_length_query - context_length
- self.assertGreater(max_length_query, context_length)
- self.assertEqual(query_responses.shape, (bs, max_length_query))
- self.assertEqual(logits.shape, (bs, max_length_logits, self.model.config.vocab_size))
+ assert max_length_query > context_length
+ assert query_responses.shape == (bs, max_length_query)
+ assert logits.shape == (bs, max_length_logits, self.model.config.vocab_size)
class TestComputeAccuracy(TrlTestCase):
@@ -404,7 +397,7 @@ def test_token_classification_task(self):
)
expected_accuracy = 0.5 # 2 matches, 2 mismatches
result = compute_accuracy(eval_pred)
- self.assertAlmostEqual(result["accuracy"], expected_accuracy)
+ assert round(abs(result["accuracy"]-expected_accuracy), 7) == 0
def test_token_classification_task_with_ignored_tokens_0(self):
eval_pred = (
@@ -418,7 +411,7 @@ def test_token_classification_task_with_ignored_tokens_0(self):
)
expected_accuracy = 1.0 # All non-ignored tokens match
result = compute_accuracy(eval_pred)
- self.assertAlmostEqual(result["accuracy"], expected_accuracy)
+ assert round(abs(result["accuracy"]-expected_accuracy), 7) == 0
def test_token_classification_task_with_ignored_tokens_1(self):
eval_pred = (
@@ -432,7 +425,7 @@ def test_token_classification_task_with_ignored_tokens_1(self):
)
expected_accuracy = 1 / 3 # 1 match, 2 mismatch, 1 ignored
result = compute_accuracy(eval_pred)
- self.assertAlmostEqual(result["accuracy"], expected_accuracy)
+ assert round(abs(result["accuracy"]-expected_accuracy), 7) == 0
def test_rewards_comparison_task(self):
eval_pred = (
@@ -450,12 +443,12 @@ def test_rewards_comparison_task(self):
with self.assertLogs("trl.trainer.utils", level="WARNING") as cm:
result = compute_accuracy(eval_pred)
- self.assertAlmostEqual(result["accuracy"], expected_accuracy)
+ assert round(abs(result["accuracy"]-expected_accuracy), 7) == 0
expected_warning = (
"There are 1 out of 3 instances where the predictions for both options are equal. "
"These instances are ignored in the accuracy computation."
)
- self.assertIn(expected_warning, cm.output[0])
+ assert expected_warning in cm.output[0]
class TestFlushLeft(TrlTestCase):
@@ -469,9 +462,9 @@ def test_basic_case(self):
expected_tensor1 = torch.tensor([[2, 3, 4], [5, 6, 0]])
expected_tensor2 = torch.tensor([[7, 8, 9], [10, 11, 0]])
- self.assertTrue(torch.equal(new_mask, expected_mask))
- self.assertTrue(torch.equal(new_tensor1, expected_tensor1))
- self.assertTrue(torch.equal(new_tensor2, expected_tensor2))
+ assert torch.equal(new_mask, expected_mask)
+ assert torch.equal(new_tensor1, expected_tensor1)
+ assert torch.equal(new_tensor2, expected_tensor2)
def test_single_row(self):
mask = torch.tensor([[0, 0, 1, 1]])
@@ -481,8 +474,8 @@ def test_single_row(self):
expected_mask = torch.tensor([[1, 1]])
expected_tensor1 = torch.tensor([[2, 3]])
- self.assertTrue(torch.equal(new_mask, expected_mask))
- self.assertTrue(torch.equal(new_tensor1, expected_tensor1))
+ assert torch.equal(new_mask, expected_mask)
+ assert torch.equal(new_tensor1, expected_tensor1)
def test_no_shift_needed(self):
mask = torch.tensor([[1, 1, 0, 0], [1, 0, 0, 0]])
@@ -492,14 +485,14 @@ def test_no_shift_needed(self):
expected_mask = torch.tensor([[1, 1], [1, 0]])
expected_tensor1 = torch.tensor([[5, 6], [7, 0]])
- self.assertTrue(torch.equal(new_mask, expected_mask))
- self.assertTrue(torch.equal(new_tensor1, expected_tensor1))
+ assert torch.equal(new_mask, expected_mask)
+ assert torch.equal(new_tensor1, expected_tensor1)
def test_no_tensors(self):
mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]])
new_mask = flush_left(mask)
expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
- self.assertTrue(torch.equal(new_mask, expected_mask))
+ assert torch.equal(new_mask, expected_mask)
class TestFlushRight(TrlTestCase):
@@ -513,9 +506,9 @@ def test_basic_case(self):
expected_tensor1 = torch.tensor([[2, 3, 4], [0, 5, 6]])
expected_tensor2 = torch.tensor([[7, 8, 9], [0, 10, 11]])
- self.assertTrue(torch.equal(new_mask, expected_mask))
- self.assertTrue(torch.equal(new_tensor1, expected_tensor1))
- self.assertTrue(torch.equal(new_tensor2, expected_tensor2))
+ assert torch.equal(new_mask, expected_mask)
+ assert torch.equal(new_tensor1, expected_tensor1)
+ assert torch.equal(new_tensor2, expected_tensor2)
def test_single_row(self):
mask = torch.tensor([[1, 1, 0, 0]])
@@ -525,8 +518,8 @@ def test_single_row(self):
expected_mask = torch.tensor([[1, 1]])
expected_tensor1 = torch.tensor([[2, 3]])
- self.assertTrue(torch.equal(new_mask, expected_mask))
- self.assertTrue(torch.equal(new_tensor1, expected_tensor1))
+ assert torch.equal(new_mask, expected_mask)
+ assert torch.equal(new_tensor1, expected_tensor1)
def test_no_shift_needed(self):
mask = torch.tensor([[0, 0, 1, 1], [0, 0, 0, 1]])
@@ -536,14 +529,14 @@ def test_no_shift_needed(self):
expected_mask = torch.tensor([[1, 1], [0, 1]])
expected_tensor1 = torch.tensor([[5, 6], [0, 7]])
- self.assertTrue(torch.equal(new_mask, expected_mask))
- self.assertTrue(torch.equal(new_tensor1, expected_tensor1))
+ assert torch.equal(new_mask, expected_mask)
+ assert torch.equal(new_tensor1, expected_tensor1)
def test_no_tensors(self):
mask = torch.tensor([[1, 1, 1, 0, 0], [0, 0, 1, 1, 0]])
new_mask = flush_right(mask)
expected_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])
- self.assertTrue(torch.equal(new_mask, expected_mask))
+ assert torch.equal(new_mask, expected_mask)
class RepeatRandomSamplerTester(TrlTestCase):
@@ -564,7 +557,7 @@ def test_sampler_no_shuffle(self):
sampler = RepeatSampler(dataset, mini_repeat_count=2, shuffle=False)
sampled = list(sampler)
expected = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]
- self.assertEqual(sampled, expected)
+ assert sampled == expected
def test_sampler_no_repeat(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]
@@ -706,7 +699,7 @@ def test_print_output(self, mock_stdout):
╰──────────────────────────────────────────────────────────────────╯
""")
- self.assertEqual(output, expected_output)
+ assert output == expected_output
@patch("sys.stdout", new_callable=StringIO)
def test_num_samples(self, mock_stdout):
@@ -741,7 +734,7 @@ def test_num_samples(self, mock_stdout):
╰─────────────────────────────────────────────╯
"""),
]
- self.assertIn(output, possible_outputs)
+ assert output in possible_outputs
@patch("sys.stdout", new_callable=StringIO)
def test_print_messages(self, mock_stdout):
@@ -790,7 +783,7 @@ def test_print_messages(self, mock_stdout):
╰──────────────────────────────────────────────────────────────────────────────╯
""")
- self.assertEqual(output, expected_output)
+ assert output == expected_output
@patch("sys.stdout", new_callable=StringIO)
def test_print_messages_with_tools(self, mock_stdout):
@@ -829,7 +822,7 @@ def test_print_messages_with_tools(self, mock_stdout):
╰──────────────────────────────────────────────────────────────────────────────╯
""")
- self.assertEqual(output, expected_output)
+ assert output == expected_output
class TestSelectiveLogSoftmax(TrlTestCase):
@@ -848,7 +841,7 @@ def test_selective_log_softmax(self, dtype):
if dtype in [torch.float16, torch.bfloat16]:
# half-precision dtypes fall back to an exact method
- self.assertTrue(torch.equal(actual_output, expected_output))
+ assert torch.equal(actual_output, expected_output)
else:
torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5)
@@ -861,8 +854,8 @@ def test_shuffle_preserves_shape(self):
shuffled = shuffle_sequence_dict(tensor_dict)
- self.assertEqual(shuffled["x"].shape, x.shape)
- self.assertEqual(shuffled["y"].shape, y.shape)
+ assert shuffled["x"].shape == x.shape
+ assert shuffled["y"].shape == y.shape
def test_shuffle_consistent_across_tensors(self):
# Use known patterns to check alignment
@@ -878,13 +871,13 @@ def test_shuffle_consistent_across_tensors(self):
y_val = shuffled["y"][i].item()
if torch.equal(x_row, torch.tensor([10, 11])):
- self.assertEqual(y_val, 1)
+ assert y_val == 1
elif torch.equal(x_row, torch.tensor([20, 21])):
- self.assertEqual(y_val, 2)
+ assert y_val == 2
elif torch.equal(x_row, torch.tensor([30, 31])):
- self.assertEqual(y_val, 3)
+ assert y_val == 3
else:
- self.fail("Unexpected x row in shuffled output.")
+ pytest.fail("Unexpected x row in shuffled output.")
def test_none_tensor_remains_none(self):
x = torch.arange(6).reshape(3, 2)
@@ -892,8 +885,8 @@ def test_none_tensor_remains_none(self):
shuffled = shuffle_sequence_dict(tensor_dict)
- self.assertIsNone(shuffled["y"])
- self.assertEqual(shuffled["x"].shape, x.shape)
+ assert shuffled["y"] is None
+ assert shuffled["x"].shape == x.shape
def test_shuffle_with_list(self):
x = torch.tensor([[10, 11], [20, 21], [30, 31]])
@@ -909,13 +902,13 @@ def test_shuffle_with_list(self):
y_val = shuffled["y"][i]
if torch.equal(x_row, torch.tensor([10, 11])):
- self.assertEqual(y_val, "a")
+ assert y_val == "a"
elif torch.equal(x_row, torch.tensor([20, 21])):
- self.assertEqual(y_val, "b")
+ assert y_val == "b"
elif torch.equal(x_row, torch.tensor([30, 31])):
- self.assertEqual(y_val, "c")
+ assert y_val == "c"
else:
- self.fail("Unexpected x row in shuffled output.")
+ pytest.fail("Unexpected x row in shuffled output.")
class SplitTensorDictTester(TrlTestCase):
@@ -928,10 +921,10 @@ def test_split_equal_chunks(self):
expected_x_chunks = torch.chunk(x, 3, dim=0)
expected_y_chunks = torch.chunk(y, 3, dim=0)
- self.assertEqual(len(result), 3)
+ assert len(result) == 3
for i in range(3):
- self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i]))
- self.assertTrue(torch.equal(result[i]["y"], expected_y_chunks[i]))
+ assert torch.equal(result[i]["x"], expected_x_chunks[i])
+ assert torch.equal(result[i]["y"], expected_y_chunks[i])
def test_with_none_tensor(self):
x = torch.arange(12).reshape(6, 2)
@@ -940,10 +933,10 @@ def test_with_none_tensor(self):
result = split_tensor_dict(tensor_dict, 2)
expected_x_chunks = torch.chunk(x, 2, dim=0)
- self.assertEqual(len(result), 2)
+ assert len(result) == 2
for i in range(2):
- self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i]))
- self.assertIsNone(result[i]["y"])
+ assert torch.equal(result[i]["x"], expected_x_chunks[i])
+ assert result[i]["y"] is None
def test_with_scalar(self):
x = torch.arange(12).reshape(6, 2)
@@ -952,10 +945,10 @@ def test_with_scalar(self):
result = split_tensor_dict(tensor_dict, 2)
expected_x_chunks = torch.chunk(x, 2, dim=0)
- self.assertEqual(len(result), 2)
+ assert len(result) == 2
for i in range(2):
- self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i]))
- self.assertTrue(torch.equal(result[i]["y"], torch.tensor(1)))
+ assert torch.equal(result[i]["x"], expected_x_chunks[i])
+ assert torch.equal(result[i]["y"], torch.tensor(1))
class SplitPixelValuesByGridTester(TrlTestCase):
@@ -966,14 +959,14 @@ def test_split_correctly_0(self):
"pixel_values": torch.arange(8 * 3).reshape(8, 3), # Shape: [8, 3]
}
result = split_pixel_values_by_grid(batch)
- self.assertIsInstance(result["pixel_values"], list)
- self.assertEqual(len(result["pixel_values"]), 2)
- self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:4]))
- self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][4:]))
- self.assertIsInstance(result["image_grid_thw"], list)
- self.assertEqual(len(result["image_grid_thw"]), 2)
- self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]])))
- self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2]])))
+ assert isinstance(result["pixel_values"], list)
+ assert len(result["pixel_values"]) == 2
+ assert torch.equal(result["pixel_values"][0], batch["pixel_values"][:4])
+ assert torch.equal(result["pixel_values"][1], batch["pixel_values"][4:])
+ assert isinstance(result["image_grid_thw"], list)
+ assert len(result["image_grid_thw"]) == 2
+ assert torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]]))
+ assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2]]))
def test_split_correctly_1(self):
batch = {
@@ -982,19 +975,19 @@ def test_split_correctly_1(self):
"pixel_values": torch.arange(12 * 3).reshape(12, 3), # Shape: [12, 3]
}
result = split_pixel_values_by_grid(batch)
- self.assertIsInstance(result["pixel_values"], list)
- self.assertEqual(len(result["pixel_values"]), 2)
- self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:4]))
- self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][4:12]))
- self.assertIsInstance(result["image_grid_thw"], list)
- self.assertEqual(len(result["image_grid_thw"]), 2)
- self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]])))
- self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 4]])))
+ assert isinstance(result["pixel_values"], list)
+ assert len(result["pixel_values"]) == 2
+ assert torch.equal(result["pixel_values"][0], batch["pixel_values"][:4])
+ assert torch.equal(result["pixel_values"][1], batch["pixel_values"][4:12])
+ assert isinstance(result["image_grid_thw"], list)
+ assert len(result["image_grid_thw"]) == 2
+ assert torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]]))
+ assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 4]]))
def test_missing_keys(self):
batch = {"pixel_values": torch.tensor([1.0])}
result = split_pixel_values_by_grid(batch)
- self.assertEqual(result, batch)
+ assert result == batch
def test_mismatched_length(self):
batch = {
@@ -1002,7 +995,7 @@ def test_mismatched_length(self):
"num_images": [1, 1],
"pixel_values": torch.randn(3, 5), # Only 3 rows
}
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
split_pixel_values_by_grid(batch)
def test_multi_images(self):
@@ -1012,14 +1005,14 @@ def test_multi_images(self):
"pixel_values": torch.arange(8 * 3).reshape(8, 3), # Shape: [8, 3]
}
result = split_pixel_values_by_grid(batch)
- self.assertIsInstance(result["pixel_values"], list)
- self.assertEqual(len(result["pixel_values"]), 2)
- self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:2]))
- self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][2:]))
- self.assertIsInstance(result["image_grid_thw"], list)
- self.assertEqual(len(result["image_grid_thw"]), 2)
- self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 1, 2]])))
- self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]])))
+ assert isinstance(result["pixel_values"], list)
+ assert len(result["pixel_values"]) == 2
+ assert torch.equal(result["pixel_values"][0], batch["pixel_values"][:2])
+ assert torch.equal(result["pixel_values"][1], batch["pixel_values"][2:])
+ assert isinstance(result["image_grid_thw"], list)
+ assert len(result["image_grid_thw"]) == 2
+ assert torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 1, 2]]))
+ assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]]))
class TruncateWithProtectedTokensTester(TrlTestCase):
@@ -1032,7 +1025,7 @@ def test_basic_example(self):
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
expected_ids = [2, 3, 5]
- self.assertEqual(new_ids, expected_ids)
+ assert new_ids == expected_ids
def test_no_truncation_needed(self):
"""Test when target length equals current length."""
@@ -1042,7 +1035,7 @@ def test_no_truncation_needed(self):
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
- self.assertEqual(new_ids, prompt_ids)
+ assert new_ids == prompt_ids
def test_no_protected_tokens(self):
"""Test truncation with no protected tokens (normal right truncation)."""
@@ -1053,7 +1046,7 @@ def test_no_protected_tokens(self):
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
expected_ids = [3, 4, 5] # Last 3 tokens
- self.assertEqual(new_ids, expected_ids)
+ assert new_ids == expected_ids
def test_all_tokens_protected(self):
"""Test when all remaining tokens are protected."""
@@ -1064,7 +1057,7 @@ def test_all_tokens_protected(self):
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
expected_ids = [3, 4, 5]
- self.assertEqual(new_ids, expected_ids)
+ assert new_ids == expected_ids
def test_too_many_protected_tokens(self):
"""Test error when too many protected tokens for target length."""
@@ -1072,7 +1065,7 @@ def test_too_many_protected_tokens(self):
protected_tokens = [1, 2, 3, 4]
target_length = 3
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
def test_single_batch_single_token(self):
@@ -1083,7 +1076,7 @@ def test_single_batch_single_token(self):
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
- self.assertEqual(new_ids, prompt_ids)
+ assert new_ids == prompt_ids
def test_order_preservation(self):
"""Test that relative order is preserved."""
@@ -1097,7 +1090,7 @@ def test_order_preservation(self):
# Order should be: 2, 3, 30, 40 (maintaining original relative positions)
expected_ids = [2, 3, 30, 40]
- self.assertEqual(new_ids, expected_ids)
+ assert new_ids == expected_ids
class UnsplitPixelValuesByGridTester(TrlTestCase):
@@ -1108,14 +1101,14 @@ def test_unsplit_correctly(self):
image_grid_thw_merged = torch.cat(image_grid_thw, dim=0)
batch = {"pixel_values": pixel_values, "image_grid_thw": image_grid_thw, "other_key": torch.tensor([1])}
result = unsplit_pixel_values_by_grid(batch)
- self.assertIsInstance(result["pixel_values"], torch.Tensor)
- self.assertTrue(torch.allclose(result["pixel_values"], pixel_values_merged))
- self.assertIsInstance(result["image_grid_thw"], torch.Tensor)
- self.assertTrue(torch.equal(result["image_grid_thw"], image_grid_thw_merged))
- self.assertIn("other_key", result)
+ assert isinstance(result["pixel_values"], torch.Tensor)
+ assert torch.allclose(result["pixel_values"], pixel_values_merged)
+ assert isinstance(result["image_grid_thw"], torch.Tensor)
+ assert torch.equal(result["image_grid_thw"], image_grid_thw_merged)
+ assert "other_key" in result
def test_no_op_if_not_list(self):
original = torch.randn(5, 3)
batch = {"pixel_values": original}
result = unsplit_pixel_values_by_grid(batch)
- self.assertTrue(torch.equal(result["pixel_values"], original))
+ assert torch.equal(result["pixel_values"], original)
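The conversion in this file keeps direct mechanical equivalents: `assertAlmostEqual` becomes `round(abs(a - b), 7) == 0` and the `assertLogs` context manager is left untouched. For reference only, a minimal sketch of the more idiomatic pytest counterparts (`pytest.approx` and the `caplog` fixture); the names and values below are illustrative and not part of this patch:

    import logging

    import pytest

    def test_accuracy_value():
        result = {"accuracy": 0.50000001}
        # Roughly equivalent to assertAlmostEqual(..., places=7): abs= sets an absolute tolerance.
        assert result["accuracy"] == pytest.approx(0.5, abs=1e-7)

    def test_warning_is_logged(caplog):
        with caplog.at_level(logging.WARNING, logger="trl.trainer.utils"):
            logging.getLogger("trl.trainer.utils").warning("predictions for both options are equal")
        # caplog.text aggregates the captured records, much like assertLogs' cm.output.
        assert "predictions for both options are equal" in caplog.text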
diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py
index 08a302da41c..836cd124f16 100644
--- a/tests/test_vllm_client_server.py
+++ b/tests/test_vllm_client_server.py
@@ -27,28 +27,26 @@
class TestChunkList(TrlTestCase):
def test_even_split(self):
- self.assertEqual(chunk_list([1, 2, 3, 4, 5, 6], 2), [[1, 2, 3], [4, 5, 6]])
+ assert chunk_list([1, 2, 3, 4, 5, 6], 2) == [[1, 2, 3], [4, 5, 6]]
def test_uneven_split(self):
- self.assertEqual(chunk_list([1, 2, 3, 4, 5, 6], 4), [[1, 2], [3, 4], [5], [6]])
+ assert chunk_list([1, 2, 3, 4, 5, 6], 4) == [[1, 2], [3, 4], [5], [6]]
def test_more_chunks_than_elements(self):
- self.assertEqual(chunk_list([1, 2, 3, 4, 5, 6], 8), [[1], [2], [3], [4], [5], [6], [], []])
+ assert chunk_list([1, 2, 3, 4, 5, 6], 8) == [[1], [2], [3], [4], [5], [6], [], []]
def test_n_equals_len(self):
- self.assertEqual(chunk_list([1, 2, 3], 3), [[1], [2], [3]])
+ assert chunk_list([1, 2, 3], 3) == [[1], [2], [3]]
def test_n_is_1(self):
- self.assertEqual(chunk_list([1, 2, 3], 1), [[1, 2, 3]])
+ assert chunk_list([1, 2, 3], 1) == [[1, 2, 3]]
def test_single_element_list(self):
- self.assertEqual(chunk_list([42], 2), [[42], []])
+ assert chunk_list([42], 2) == [[42], []]
def test_any_dtype(self):
- self.assertEqual(
- chunk_list([1, "two", 3.0, {"four": 4}, ["f", "i", "v", "e"]], 2),
- [[1, "two", 3.0], [{"four": 4}, ["f", "i", "v", "e"]]],
- )
+ assert chunk_list([1, "two", 3.0, {"four": 4}, ["f", "i", "v", "e"]], 2) == \
+ [[1, "two", 3.0], [{"four": 4}, ["f", "i", "v", "e"]]]
@pytest.mark.slow
@@ -77,14 +75,14 @@ def test_generate(self):
outputs = self.client.generate(prompts)["completion_ids"]
# Check that the output is a list
- self.assertIsInstance(outputs, list)
+ assert isinstance(outputs, list)
# Check that the number of generated sequences is equal to the number of prompts
- self.assertEqual(len(outputs), len(prompts))
+ assert len(outputs) == len(prompts)
# Check that the generated sequences are lists of integers
for seq in outputs:
- self.assertTrue(all(isinstance(tok, int) for tok in seq))
+ assert all(isinstance(tok, int) for tok in seq)
def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
@@ -93,18 +91,18 @@ def test_generate_with_params(self):
]
# Check that the output is a list
- self.assertIsInstance(outputs, list)
+ assert isinstance(outputs, list)
# Check that the number of generated sequences is 2 times the number of prompts
- self.assertEqual(len(outputs), 2 * len(prompts))
+ assert len(outputs) == 2 * len(prompts)
# Check that the generated sequences are lists of integers
for seq in outputs:
- self.assertTrue(all(isinstance(tok, int) for tok in seq))
+ assert all(isinstance(tok, int) for tok in seq)
# Check that the length of the generated sequences is less than or equal to 32
for seq in outputs:
- self.assertLessEqual(len(seq), 32)
+ assert len(seq) <= 32
def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
@@ -153,14 +151,14 @@ def test_generate(self):
outputs = self.client.generate(prompts)["completion_ids"]
# Check that the output is a list
- self.assertIsInstance(outputs, list)
+ assert isinstance(outputs, list)
# Check that the number of generated sequences is equal to the number of prompts
- self.assertEqual(len(outputs), len(prompts))
+ assert len(outputs) == len(prompts)
# Check that the generated sequences are lists of integers
for seq in outputs:
- self.assertTrue(all(isinstance(tok, int) for tok in seq))
+ assert all(isinstance(tok, int) for tok in seq)
def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
@@ -169,18 +167,18 @@ def test_generate_with_params(self):
]
# Check that the output is a list
- self.assertIsInstance(outputs, list)
+ assert isinstance(outputs, list)
# Check that the number of generated sequences is 2 times the number of prompts
- self.assertEqual(len(outputs), 2 * len(prompts))
+ assert len(outputs) == 2 * len(prompts)
# Check that the generated sequences are lists of integers
for seq in outputs:
- self.assertTrue(all(isinstance(tok, int) for tok in seq))
+ assert all(isinstance(tok, int) for tok in seq)
# Check that the length of the generated sequences is less than or equal to 32
for seq in outputs:
- self.assertLessEqual(len(seq), 32)
+ assert len(seq) <= 32
def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
@@ -231,14 +229,14 @@ def test_generate(self):
outputs = self.client.generate(prompts)["completion_ids"]
# Check that the output is a list
- self.assertIsInstance(outputs, list)
+ assert isinstance(outputs, list)
# Check that the number of generated sequences is equal to the number of prompts
- self.assertEqual(len(outputs), len(prompts))
+ assert len(outputs) == len(prompts)
# Check that the generated sequences are lists of integers
for seq in outputs:
- self.assertTrue(all(isinstance(tok, int) for tok in seq))
+ assert all(isinstance(tok, int) for tok in seq)
def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
@@ -289,14 +287,14 @@ def test_generate(self):
outputs = self.client.generate(prompts)["completion_ids"]
# Check that the output is a list
- self.assertIsInstance(outputs, list)
+ assert isinstance(outputs, list)
# Check that the number of generated sequences is equal to the number of prompts
- self.assertEqual(len(outputs), len(prompts))
+ assert len(outputs) == len(prompts)
# Check that the generated sequences are lists of integers
for seq in outputs:
- self.assertTrue(all(isinstance(tok, int) for tok in seq))
+ assert all(isinstance(tok, int) for tok in seq)
def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
@@ -345,8 +343,8 @@ def test_init_communicator_with_device_int(self):
# Test basic functionality
prompts = ["Hello, AI!"]
outputs = client.generate(prompts)["completion_ids"]
- self.assertIsInstance(outputs, list)
- self.assertEqual(len(outputs), len(prompts))
+ assert isinstance(outputs, list)
+ assert len(outputs) == len(prompts)
client.close_communicator()
@@ -358,8 +356,8 @@ def test_init_communicator_with_device_string(self):
# Test basic functionality
prompts = ["Hello, AI!"]
outputs = client.generate(prompts)["completion_ids"]
- self.assertIsInstance(outputs, list)
- self.assertEqual(len(outputs), len(prompts))
+ assert isinstance(outputs, list)
+ assert len(outputs) == len(prompts)
client.close_communicator()
@@ -374,8 +372,8 @@ def test_init_communicator_with_torch_device(self):
# Test basic functionality
prompts = ["Hello, AI!"]
outputs = client.generate(prompts)["completion_ids"]
- self.assertIsInstance(outputs, list)
- self.assertEqual(len(outputs), len(prompts))
+ assert isinstance(outputs, list)
+ assert len(outputs) == len(prompts)
client.close_communicator()
diff --git a/tests/test_xpo_trainer.py b/tests/test_xpo_trainer.py
index 9d50b542a03..7a69455df33 100644
--- a/tests/test_xpo_trainer.py
+++ b/tests/test_xpo_trainer.py
@@ -65,7 +65,7 @@ def test_xpo_trainer_training(self, config_name):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
@require_peft
def test_training_with_peft(self):
@@ -93,7 +93,7 @@ def test_training_with_peft(self):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
@require_peft
def test_training_with_peft_and_ref_model(self):
@@ -122,7 +122,7 @@ def test_training_with_peft_and_ref_model(self):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
@require_peft
def test_training_with_peft_model_and_peft_config(self):
@@ -153,7 +153,7 @@ def test_training_with_peft_model_and_peft_config(self):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
@require_peft
def test_training_pre_pefted_model_implicit_ref(self):
@@ -182,7 +182,7 @@ def test_training_pre_pefted_model_implicit_ref(self):
trainer.train()
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
@require_llm_blender
@parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
@@ -213,4 +213,4 @@ def test_xpo_trainer_judge_training(self, config_name):
trainer.train()
# Check if training loss is available
- self.assertIn("train_loss", trainer.state.log_history[-1])
+ assert "train_loss" in trainer.state.log_history[-1]
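For reference, the assertion mapping applied mechanically throughout this commit, collected as a small self-contained sketch (the values are illustrative, not taken from the tests):

    import pytest

    def test_assertion_mapping_examples():
        # self.assertEqual(a, b)               -> assert a == b
        assert [1, 2] + [3] == [1, 2, 3]
        # self.assertTrue(x) / assertFalse(x)  -> assert x / assert not x
        assert isinstance({}, dict)
        # self.assertIn / assertNotIn          -> assert a in b / assert a not in b
        assert "train_loss" in {"train_loss": 0.1}
        # self.assertIsNone / assertIsInstance -> assert x is None / assert isinstance(x, T)
        assert {}.get("missing") is None
        # self.assertAlmostEqual(a, b)         -> assert round(abs(a - b), 7) == 0
        assert round(abs(0.1 + 0.2 - 0.3), 7) == 0
        # self.assertRaises(E) / self.fail(m)  -> pytest.raises(E) / pytest.fail(m)
        with pytest.raises(ValueError):
            int("not a number")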
From 283c05908c65f9ada9b42daa2a3153ab8c2be8ee Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral
<8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 1 Oct 2025 10:24:26 +0200
Subject: [PATCH 03/16] Fix style
---
tests/slow/test_grpo_slow.py | 8 ++--
tests/test_activation_offloading.py | 3 +-
tests/test_bco_trainer.py | 9 ++--
tests/test_cli_utils.py | 2 +-
tests/test_data_utils.py | 48 ++++++++++----------
tests/test_dataset_formatting.py | 12 +++--
tests/test_dpo_trainer.py | 69 +++++++++++++++--------------
tests/test_gkd_trainer.py | 17 +++----
tests/test_grpo_trainer.py | 20 ++++-----
tests/test_kto_trainer.py | 20 +++++----
tests/test_modeling_value_head.py | 22 +++++----
tests/test_online_dpo_trainer.py | 4 +-
tests/test_peft_models.py | 17 ++++---
tests/test_prm_trainer.py | 54 ++++++++++------------
tests/test_rewards.py | 2 +-
tests/test_rloo_trainer.py | 2 +-
tests/test_sft_trainer.py | 20 +++++----
tests/test_trainers_args.py | 28 ++++++------
tests/test_utils.py | 24 +++++-----
tests/test_vllm_client_server.py | 6 ++-
20 files changed, 202 insertions(+), 185 deletions(-)
diff --git a/tests/slow/test_grpo_slow.py b/tests/slow/test_grpo_slow.py
index 5f4400b9115..3a453714daa 100644
--- a/tests/slow/test_grpo_slow.py
+++ b/tests/slow/test_grpo_slow.py
@@ -444,9 +444,9 @@ def dummy_reward_func(completions, **kwargs):
# Check if signature columns were set properly
if trainer._signature_columns is not None:
# Should include 'image' in signature columns for VLM processors
- assert "image" in \
- trainer._signature_columns, \
+ assert "image" in trainer._signature_columns, (
"Should include 'image' in signature columns for VLM"
+ )
# Should not emit any warnings about VLM incompatibility
incompatibility_warnings = [
@@ -455,9 +455,9 @@ def dummy_reward_func(completions, **kwargs):
if "does not support VLMs" in str(w_item.message)
or "not compatible" in str(w_item.message).lower()
]
- assert len(incompatibility_warnings) == \
- 0, \
+ assert len(incompatibility_warnings) == 0, (
f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}"
+ )
# Test passes if we get this far without exceptions
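The reflow above, repeated across the files that follow, swaps backslash continuations for the implicit line continuation that PEP 8 recommends: the assertion message is wrapped in parentheses so the statement can span lines without `\`. Both forms below are equivalent at runtime; the second is what this style pass settles on (the `columns` value is a stand-in for illustration):

    columns = ["prompt", "completion", "image"]
    # Backslash continuation, as produced by the previous commit:
    assert "image" in columns, \
        "Should include 'image' in signature columns for VLM"
    # Parenthesized message, as preferred by PEP 8 and common formatters:
    assert "image" in columns, (
        "Should include 'image' in signature columns for VLM"
    )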
diff --git a/tests/test_activation_offloading.py b/tests/test_activation_offloading.py
index b1e8f59d61d..6c9ae24b8ae 100644
--- a/tests/test_activation_offloading.py
+++ b/tests/test_activation_offloading.py
@@ -72,8 +72,9 @@ def test_offloading_with_peft_models(self) -> None:
for name_orig, grad_orig in grads_original:
for name_param, param in model.named_parameters():
if name_param == name_orig and param.requires_grad and param.grad is not None:
- assert torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5), \
+ assert torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5), (
f"Gradient mismatch for {name_orig}"
+ )
@require_torch_accelerator
def test_noop_manager_with_offloading(self):
diff --git a/tests/test_bco_trainer.py b/tests/test_bco_trainer.py
index b1fdaf8d8cf..91f905eb628 100644
--- a/tests/test_bco_trainer.py
+++ b/tests/test_bco_trainer.py
@@ -14,6 +14,7 @@
from functools import partial
+import pytest
import torch
from accelerate import Accelerator
from datasets import load_dataset
@@ -26,7 +27,6 @@
from trl.trainer.bco_trainer import _process_tokens, _tokenize
from .testing_utils import TrlTestCase, require_no_wandb, require_sklearn
-import pytest
if is_peft_available():
@@ -363,8 +363,11 @@ def test_generate_during_eval_no_wandb(self):
report_to="none",
)
- with pytest.raises(ValueError, match="`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
- " Please install `wandb` or `comet-ml` to resolve."):
+ with pytest.raises(
+ ValueError,
+ match="`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
+ " Please install `wandb` or `comet-ml` to resolve.",
+ ):
BCOTrainer(
model=model,
args=training_args,
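A side note on the `pytest.raises(..., match=...)` calls reformatted above: `match` is applied with `re.search` against the string form of the exception, so it is a regular expression rather than a literal. The messages used in these tests contain no metacharacters beyond `.` (which simply matches any character), so they match themselves; messages containing parentheses or brackets would need `re.escape`. A minimal, self-contained illustration (the error message is invented for the example):

    import re

    import pytest

    def test_match_is_treated_as_a_regex():
        msg = "`generate_during_eval=True` requires wandb (or comet-ml)."
        # re.escape makes the literal parentheses and dot safe inside the match pattern.
        with pytest.raises(ValueError, match=re.escape(msg)):
            raise ValueError(msg)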
diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py
index 01417c258bc..f8196dd5377 100644
--- a/tests/test_cli_utils.py
+++ b/tests/test_cli_utils.py
@@ -17,13 +17,13 @@
from dataclasses import dataclass
from unittest.mock import mock_open, patch
+import pytest
from datasets import DatasetDict, load_dataset
from trl import DatasetMixtureConfig, TrlParser, get_dataset
from trl.scripts.utils import DatasetConfig
from .testing_utils import TrlTestCase
-import pytest
@dataclass
diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py
index 88c63868094..abb27258e3f 100644
--- a/tests/test_data_utils.py
+++ b/tests/test_data_utils.py
@@ -675,46 +675,46 @@ class UnpairPreferenceDatasetTester(TrlTestCase):
def test_unpair_preference_dataset(self):
# Test that a paired dataset is correctly converted to unpaired
unpaired_dataset = unpair_preference_dataset(self.paired_dataset)
- assert unpaired_dataset.to_dict() == \
- self.unpaired_dataset.to_dict(), \
+ assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), (
"The paired dataset should be converted to unpaired."
+ )
def test_unpair_preference_dataset_dict(self):
# Test that a paired dataset dict is correctly converted to unpaired
paired_dataset_dict = DatasetDict({"abc": self.paired_dataset})
unpaired_dataset_dict = unpair_preference_dataset(paired_dataset_dict)
- assert unpaired_dataset_dict["abc"].to_dict() == \
- self.unpaired_dataset.to_dict(), \
+ assert unpaired_dataset_dict["abc"].to_dict() == self.unpaired_dataset.to_dict(), (
"The paired dataset should be converted to unpaired."
+ )
def test_maybe_unpair_preference_dataset(self):
# Test that a paired dataset is correctly converted to unpaired with maybe_unpair_preference_dataset
unpaired_dataset = maybe_unpair_preference_dataset(self.paired_dataset)
- assert unpaired_dataset.to_dict() == \
- self.unpaired_dataset.to_dict(), \
+ assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), (
"The paired dataset should be converted to unpaired."
+ )
def test_maybe_unpair_preference_dataset_dict(self):
# Test that a paired dataset dict is correctly converted to unpaired with maybe_unpair_preference_dataset
paired_dataset_dict = DatasetDict({"abc": self.paired_dataset})
unpaired_dataset_dict = maybe_unpair_preference_dataset(paired_dataset_dict)
- assert unpaired_dataset_dict["abc"].to_dict() == \
- self.unpaired_dataset.to_dict(), \
+ assert unpaired_dataset_dict["abc"].to_dict() == self.unpaired_dataset.to_dict(), (
"The paired dataset should be converted to unpaired."
+ )
def test_maybe_unpair_preference_dataset_already_paired(self):
# Test that a paired dataset remains unchanged with maybe_unpair_preference_dataset
unpaired_dataset = maybe_unpair_preference_dataset(self.unpaired_dataset)
- assert unpaired_dataset.to_dict() == \
- self.unpaired_dataset.to_dict(), \
+ assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), (
"The unpaired dataset should remain unchanged."
+ )
def test_maybe_unpair_preference_dataset_dict_already_paired(self):
# Test that a paired dataset dict remains unchanged with maybe_unpair_preference_dataset
unpaired_dataset_dict = maybe_unpair_preference_dataset(DatasetDict({"abc": self.unpaired_dataset}))
- assert unpaired_dataset_dict["abc"].to_dict() == \
- self.unpaired_dataset.to_dict(), \
+ assert unpaired_dataset_dict["abc"].to_dict() == self.unpaired_dataset.to_dict(), (
"The unpaired dataset should remain unchanged."
+ )
class ExtractPromptTester(TrlTestCase):
@@ -755,44 +755,42 @@ class ExtractPromptTester(TrlTestCase):
def test_extract_prompt_conversational(self):
# Test that the prompt is correctly extracted from the dataset
example_extracted_prompt = extract_prompt(self.example_implicit_prompt_conversational)
- assert example_extracted_prompt == \
- self.example_explicit_prompt_conversational, \
+ assert example_extracted_prompt == self.example_explicit_prompt_conversational, (
"The prompt is not correctly extracted from the dataset."
+ )
def test_maybe_extract_prompt_conversational(self):
# Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_conversational)
- assert example_extracted_prompt == \
- self.example_explicit_prompt_conversational, \
+ assert example_extracted_prompt == self.example_explicit_prompt_conversational, (
"The prompt is not correctly extracted from the dataset."
+ )
def test_maybe_extract_prompt_conversational_already_explicit(self):
# Test that the prompt remains unchanged with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_conversational)
- assert example_extracted_prompt == \
- self.example_explicit_prompt_conversational, \
+ assert example_extracted_prompt == self.example_explicit_prompt_conversational, (
"The prompt should remain unchanged."
+ )
def test_extract_prompt_standard(self):
# Test that the prompt is correctly extracted from the dataset
example_extracted_prompt = extract_prompt(self.example_implicit_prompt_standard)
- assert example_extracted_prompt == \
- self.example_explicit_prompt_standard, \
+ assert example_extracted_prompt == self.example_explicit_prompt_standard, (
"The prompt is not correctly extracted from the dataset."
+ )
def test_maybe_extract_prompt_standard(self):
# Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_standard)
- assert example_extracted_prompt == \
- self.example_explicit_prompt_standard, \
+ assert example_extracted_prompt == self.example_explicit_prompt_standard, (
"The prompt is not correctly extracted from the dataset."
+ )
def test_maybe_extract_prompt_standard_already_explicit(self):
# Test that the prompt remains unchanged with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_standard)
- assert example_extracted_prompt == \
- self.example_explicit_prompt_standard, \
- "The prompt should remain unchanged."
+ assert example_extracted_prompt == self.example_explicit_prompt_standard, "The prompt should remain unchanged."
class TestPackDatasetWrapped(TrlTestCase):
diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py
index cc132aedefb..78f2e7dfb1d 100644
--- a/tests/test_dataset_formatting.py
+++ b/tests/test_dataset_formatting.py
@@ -152,8 +152,10 @@ def test_example_with_setup_model(self):
]
prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False)
- assert prompt == \
- "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
+ assert (
+ prompt
+ == "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
+ )
class CloneChatTemplateTestCase(TrlTestCase):
@@ -217,8 +219,10 @@ def test_apply_new_chat_template(self):
]
prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False)
- assert prompt == \
- "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n\n\n\n\nHi, how can I help you?<|im_end|>\n"
+ assert (
+ prompt
+ == "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n\n\n\n\nHi, how can I help you?<|im_end|>\n"
+ )
def test_clone_with_sequence_classification_model(self):
# This tokenizer doesn't have a chat_template by default
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index f429726c20b..8c0a4efb250 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -17,6 +17,7 @@
from unittest.mock import MagicMock
import numpy as np
+import pytest
import torch
from datasets import Dataset, features, load_dataset
from parameterized import parameterized
@@ -40,7 +41,6 @@
from trl import DPOConfig, DPOTrainer, FDivergenceType
from .testing_utils import TrlTestCase, require_bitsandbytes, require_no_wandb
-import pytest
if is_vision_available():
@@ -85,12 +85,11 @@ def test_tokenize_row_no_truncation_no_special_tokens(self):
)
# Assert the correct output without truncation or special tokens
- assert result == \
- {
- "prompt_input_ids": [464, 6766, 318],
- "chosen_input_ids": [4171, 2], # eos_token added
- "rejected_input_ids": [4077, 2], # eos_token added
- }
+ assert result == {
+ "prompt_input_ids": [464, 6766, 318],
+ "chosen_input_ids": [4171, 2], # eos_token added
+ "rejected_input_ids": [4077, 2], # eos_token added
+ }
def test_tokenize_row_with_truncation(self):
# Define the input features
@@ -106,12 +105,11 @@ def test_tokenize_row_with_truncation(self):
)
# Assert the correct output with truncation applied
- assert result == \
- {
- "prompt_input_ids": [6766, 318], # truncated to the last 2 tokens
- "chosen_input_ids": [4171], # truncated to 1 token
- "rejected_input_ids": [4077], # truncated to 1 token
- }
+ assert result == {
+ "prompt_input_ids": [6766, 318], # truncated to the last 2 tokens
+ "chosen_input_ids": [4171], # truncated to 1 token
+ "rejected_input_ids": [4077], # truncated to 1 token
+ }
def test_tokenize_row_with_special_tokens(self):
# Define the input features
@@ -127,12 +125,11 @@ def test_tokenize_row_with_special_tokens(self):
)
# Assert the correct output with special tokens added
- assert result == \
- {
- "prompt_input_ids": [0, 464, 6766, 318, 2], # bos_token and eos_token added
- "chosen_input_ids": [4171, 2], # eos_token added
- "rejected_input_ids": [4077, 2], # eos_token added
- }
+ assert result == {
+ "prompt_input_ids": [0, 464, 6766, 318, 2], # bos_token and eos_token added
+ "chosen_input_ids": [4171, 2], # eos_token added
+ "rejected_input_ids": [4077, 2], # eos_token added
+ }
def test_tokenize_row_with_truncation_and_special_tokens(self):
# Define the input features
@@ -148,12 +145,11 @@ def test_tokenize_row_with_truncation_and_special_tokens(self):
)
# Assert the correct output with both truncation and special tokens
- assert result == \
- {
- "prompt_input_ids": [464, 6766, 318, 2], # truncated to 4 tokens with bos_token and eos_token
- "chosen_input_ids": [4171], # truncated to 1 token
- "rejected_input_ids": [4077], # truncated to 1 token
- }
+ assert result == {
+ "prompt_input_ids": [464, 6766, 318, 2], # truncated to 4 tokens with bos_token and eos_token
+ "chosen_input_ids": [4171], # truncated to 1 token
+ "rejected_input_ids": [4077], # truncated to 1 token
+ }
class DPOTrainerTester(TrlTestCase):
@@ -573,8 +569,11 @@ def test_dpo_trainer_generate_during_eval_no_wandb(self):
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
- with pytest.raises(ValueError, match="`generate_during_eval=True` requires Weights and Biases, MLFlow or Comet to be installed."
- " Please install `wandb`, `mlflow` or `comet-ml` to resolve."):
+ with pytest.raises(
+ ValueError,
+ match="`generate_during_eval=True` requires Weights and Biases, MLFlow or Comet to be installed."
+ " Please install `wandb`, `mlflow` or `comet-ml` to resolve.",
+ ):
DPOTrainer(
model=self.model,
ref_model=None,
@@ -963,9 +962,10 @@ def test_dpo_trainer_dtype(self):
train_dataset=dummy_dataset["train"],
)
- assert "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid " \
- "`torch.dtype` (e.g., 'float32'), but got -1." in \
- str(context.exception)
+ assert (
+ "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid "
+ "`torch.dtype` (e.g., 'float32'), but got -1." in str(context.exception)
+ )
training_args = DPOConfig(
output_dir=self.tmp_dir,
@@ -984,9 +984,10 @@ def test_dpo_trainer_dtype(self):
train_dataset=dummy_dataset["train"],
)
- assert "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid " \
- "`torch.dtype` (e.g., 'float32'), but got -1." in \
- str(context.exception)
+ assert (
+ "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid "
+ "`torch.dtype` (e.g., 'float32'), but got -1." in str(context.exception)
+ )
def test_dpo_loss_alpha_div_f(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
@@ -1369,7 +1370,7 @@ def test_dpo_trainer_with_liger(self, beta, loss_type):
with torch.no_grad():
output = trainer.model(**model_inputs)
assert output is not None
- assert not ("loss" in output.keys())
+ assert "loss" not in output.keys()
def test_train_with_iterable_dataset(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
diff --git a/tests/test_gkd_trainer.py b/tests/test_gkd_trainer.py
index 11ddcaedc81..7bdec1d4553 100644
--- a/tests/test_gkd_trainer.py
+++ b/tests/test_gkd_trainer.py
@@ -70,8 +70,9 @@ def test_generate_on_policy_outputs_deterministic(self):
# Check if the generated texts start with the original prompts
for prompt, generated_text in zip(prompts, generated_texts):
- assert generated_text.startswith(prompt), \
+ assert generated_text.startswith(prompt), (
f"Generated text '{generated_text}' does not start with prompt '{prompt}'"
+ )
# Run the generation twice and check if the outputs are identical
outputs2 = GKDTrainer.generate_on_policy_outputs(
@@ -82,10 +83,10 @@ def test_generate_on_policy_outputs_deterministic(self):
# Check if the two generations are identical
assert torch.all(new_input_ids.eq(new_input_ids2)), "Deterministic generations are not identical"
- assert torch.all(new_attention_mask.eq(new_attention_mask2)), \
+ assert torch.all(new_attention_mask.eq(new_attention_mask2)), (
"Attention masks for deterministic generations are not identical"
- assert torch.all(new_labels.eq(new_labels2)), \
- "Labels for deterministic generations are not identical"
+ )
+ assert torch.all(new_labels.eq(new_labels2)), "Labels for deterministic generations are not identical"
def test_generate_on_policy_outputs(self):
prompts = ["Hello, how are you?", "What's the weather like today?"]
@@ -134,7 +135,7 @@ def setUp(self):
def test_uniform_distribution(self):
logits = torch.ones(1, 1, self.vocab_size)
loss = GKDTrainer.generalized_jsd_loss(logits, logits)
- assert round(abs(loss.item()-0), 5) == 0
+ assert round(abs(loss.item() - 0), 5) == 0
def test_generalized_jsd_loss_edge_cases(self):
# Setup
@@ -146,14 +147,14 @@ def test_generalized_jsd_loss_edge_cases(self):
expected_loss_beta_1 = F.kl_div(
F.log_softmax(teacher_logits, dim=-1), F.softmax(student_logits, dim=-1), reduction="batchmean"
)
- assert round(abs(loss_beta_1.item()-expected_loss_beta_1.item()), 5) == 0
+ assert round(abs(loss_beta_1.item() - expected_loss_beta_1.item()), 5) == 0
# Case 2: beta = 0 (should be equivalent to KL(teacher || student))
loss_beta_0 = GKDTrainer.generalized_jsd_loss(student_logits, teacher_logits, beta=0)
expected_loss_beta_0 = F.kl_div(
F.log_softmax(student_logits, dim=-1), F.softmax(teacher_logits, dim=-1), reduction="batchmean"
)
- assert round(abs(loss_beta_0.item()-expected_loss_beta_0.item()), 5) == 0
+ assert round(abs(loss_beta_0.item() - expected_loss_beta_0.item()), 5) == 0
def test_output_shape(self):
loss = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits)
@@ -195,7 +196,7 @@ def test_symmetry(self):
def test_zero_loss_for_identical_inputs(self):
identical_logits = torch.randn(self.batch_size, self.seq_length, self.vocab_size)
loss = GKDTrainer.generalized_jsd_loss(identical_logits, identical_logits)
- assert round(abs(loss.item()-0), 6) == 0
+ assert round(abs(loss.item() - 0), 6) == 0
class GKDTrainerTester(TrlTestCase):
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index 9a442ee6102..f353df444cf 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -1922,19 +1922,19 @@ def test_update_with_inputs_different_seq_len(self):
# Check for new entry with seq len 3 in buffer
assert [[3, 4, 5], [3, 4, 5]] in buffered_prompt_ids # excluded no-variance group
- assert [[1013, 1014, pad_token_id], [1015, 1016, 1017]] in buffered_completion_ids # excluded no-variance group
+ assert [
+ [1013, 1014, pad_token_id],
+ [1015, 1016, 1017],
+ ] in buffered_completion_ids # excluded no-variance group
# Check that sampled outputs contain one group with prompt_ids starting with a pad token
assert [
- [pad_token_id, 101, 102],
- [pad_token_id, 102, 103],
- ] \
- in output_prompt_ids \
- or [
- [pad_token_id, 104, 105],
- [pad_token_id, 106, 107],
- ] \
- in output_prompt_ids
+ [pad_token_id, 101, 102],
+ [pad_token_id, 102, 103],
+ ] in output_prompt_ids or [
+ [pad_token_id, 104, 105],
+ [pad_token_id, 106, 107],
+ ] in output_prompt_ids
@pytest.mark.low_priority
diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py
index ac68a00d9a7..da749abc62d 100644
--- a/tests/test_kto_trainer.py
+++ b/tests/test_kto_trainer.py
@@ -13,6 +13,7 @@
# limitations under the License.
+import pytest
import torch
from datasets import load_dataset
from parameterized import parameterized
@@ -23,7 +24,6 @@
from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize
from .testing_utils import TrlTestCase, require_no_wandb
-import pytest
class KTOTrainerTester(TrlTestCase):
@@ -167,12 +167,11 @@ def test_tokenize_and_process_tokens(self):
# the last batch remains unaltered. This is a rare scenario that does not impact the training
# process, so we exclude it from testing by iterating only up to len - 1.
for i in range(len(tokenized_kl_dataset["answer_input_ids"]) - 1):
- assert tokenized_dataset["prompt_input_ids"][i] == \
- tokenized_kl_dataset["prompt_input_ids"][i]
- assert tokenized_dataset["prompt_attention_mask"][i] == \
- tokenized_kl_dataset["prompt_attention_mask"][i]
- assert tokenized_dataset["answer_input_ids"][i] != \
- tokenized_kl_dataset["answer_input_ids"][i]
+ assert tokenized_dataset["prompt_input_ids"][i] == tokenized_kl_dataset["prompt_input_ids"][i]
+ assert (
+ tokenized_dataset["prompt_attention_mask"][i] == tokenized_kl_dataset["prompt_attention_mask"][i]
+ )
+ assert tokenized_dataset["answer_input_ids"][i] != tokenized_kl_dataset["answer_input_ids"][i]
fn_kwargs = {
"prefix": "",
@@ -295,8 +294,11 @@ def test_kto_trainer_generate_during_eval_no_wandb(self):
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")
- with pytest.raises(ValueError, match="`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
- " Please install `wandb` or `comet-ml` to resolve."):
+ with pytest.raises(
+ ValueError,
+ match="`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
+ " Please install `wandb` or `comet-ml` to resolve.",
+ ):
KTOTrainer(
model=self.model,
ref_model=None,
diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py
index 355dcb2a1c8..0aa56901f82 100644
--- a/tests/test_modeling_value_head.py
+++ b/tests/test_modeling_value_head.py
@@ -142,8 +142,8 @@ def test_from_save_transformers_sharded(self):
# Check if the weights are the same
for key in transformers_model.state_dict():
assert torch.allclose(
- transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
- )
+ transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
+ )
def test_from_save_transformers(self):
"""
@@ -163,8 +163,8 @@ def test_from_save_transformers(self):
# Check if the weights are the same
for key in transformers_model.state_dict():
assert torch.allclose(
- transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
- )
+ transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
+ )
# Check if the trl model has the same keys as the transformers model
# except the v_head
@@ -175,8 +175,9 @@ def test_from_save_transformers(self):
assert torch.allclose(trl_model.state_dict()[key], transformers_model.state_dict()[key])
# check if they have the same modules
- assert set(transformers_model_from_save.state_dict().keys()) == \
- set(transformers_model.state_dict().keys())
+ assert set(transformers_model_from_save.state_dict().keys()) == set(
+ transformers_model.state_dict().keys()
+ )
class CausalLMValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase):
@@ -261,8 +262,9 @@ def test_transformers_bf16_kwargs(self):
lm_head_namings = ["lm_head", "embed_out", "output_layer"]
- assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings), \
+ assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings), (
"Can't test the model because it doesn't have any of the expected lm_head namings"
+ )
for lm_head_naming in lm_head_namings:
if hasattr(trl_model.pretrained_model, lm_head_naming):
@@ -287,8 +289,9 @@ def test_push_to_hub(self):
assert model.state_dict().keys() == model_from_pretrained.state_dict().keys()
for name, param in model.state_dict().items():
- assert torch.allclose(param, model_from_pretrained.state_dict()[name]), \
+ assert torch.allclose(param, model_from_pretrained.state_dict()[name]), (
f"Parameter {name} is not the same after push_to_hub and from_pretrained"
+ )
class Seq2SeqValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase):
@@ -378,8 +381,9 @@ def test_push_to_hub(self):
assert model.state_dict().keys() == model_from_pretrained.state_dict().keys()
for name, param in model.state_dict().items():
- assert torch.allclose(param, model_from_pretrained.state_dict()[name]), \
+ assert torch.allclose(param, model_from_pretrained.state_dict()[name]), (
f"Parameter {name} is not the same after push_to_hub and from_pretrained"
+ )
def test_transformers_bf16_kwargs(self):
r"""
diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py
index b8eb83a0f04..54e00447593 100644
--- a/tests/test_online_dpo_trainer.py
+++ b/tests/test_online_dpo_trainer.py
@@ -478,8 +478,8 @@ def simple_reward_func(prompts, completions, completion_ids, **kwargs):
assert "train_loss" in trainer.state.log_history[-1]
assert len(trainer.reward_funcs) == 2
assert trainer.reward_weights is not None
- assert round(abs(trainer.reward_weights[0].item()-0.7), 5) == 0
- assert round(abs(trainer.reward_weights[1].item()-0.3), 5) == 0
+ assert round(abs(trainer.reward_weights[0].item() - 0.7), 5) == 0
+ assert round(abs(trainer.reward_weights[1].item() - 0.3), 5) == 0
@require_vision
diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py
index 3b68be8b817..ac62174db00 100644
--- a/tests/test_peft_models.py
+++ b/tests/test_peft_models.py
@@ -136,17 +136,21 @@ def test_save_pretrained_peft(self):
model.save_pretrained(self.tmp_dir)
# check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory
- assert os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), \
+ assert os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), (
f"{self.tmp_dir}/adapter_model.safetensors does not exist"
- assert os.path.exists(f"{self.tmp_dir}/adapter_config.json"), f"{self.tmp_dir}/adapter_config.json does not exist"
+ )
+ assert os.path.exists(f"{self.tmp_dir}/adapter_config.json"), (
+ f"{self.tmp_dir}/adapter_config.json does not exist"
+ )
# check also for `pytorch_model.bin` and make sure it only contains `v_head` weights
assert os.path.exists(f"{self.tmp_dir}/pytorch_model.bin"), f"{self.tmp_dir}/pytorch_model.bin does not exist"
# check that only keys that starts with `v_head` are in the dict
maybe_v_head = torch.load(f"{self.tmp_dir}/pytorch_model.bin", weights_only=True)
- assert all(k.startswith("v_head") for k in maybe_v_head.keys()), \
+ assert all(k.startswith("v_head") for k in maybe_v_head.keys()), (
f"keys in {self.tmp_dir}/pytorch_model.bin do not start with `v_head`"
+ )
model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir)
@@ -167,9 +171,12 @@ def test_load_pretrained_peft(self):
model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir)
# check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory
- assert os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), \
+ assert os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), (
f"{self.tmp_dir}/adapter_model.safetensors does not exist"
- assert os.path.exists(f"{self.tmp_dir}/adapter_config.json"), f"{self.tmp_dir}/adapter_config.json does not exist"
+ )
+ assert os.path.exists(f"{self.tmp_dir}/adapter_config.json"), (
+ f"{self.tmp_dir}/adapter_config.json does not exist"
+ )
# check all the weights are the same
for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()):
diff --git a/tests/test_prm_trainer.py b/tests/test_prm_trainer.py
index 76398519836..2a8083d7492 100644
--- a/tests/test_prm_trainer.py
+++ b/tests/test_prm_trainer.py
@@ -75,11 +75,10 @@ def test_tokenize_row_no_truncation(self):
is_eval=False,
)
- assert result == \
- {
- "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030],
- "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0],
- }
+ assert result == {
+ "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030],
+ "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0],
+ }
def test_tokenize_row_train_on_last_step_only(self):
# Define the input features
@@ -100,11 +99,10 @@ def test_tokenize_row_train_on_last_step_only(self):
is_eval=False,
)
- assert result == \
- {
- "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030],
- "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0],
- }
+ assert result == {
+ "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030],
+ "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0],
+ }
def test_tokenize_row_prompt_truncation(self):
# Define the input features
@@ -126,11 +124,10 @@ def test_tokenize_row_prompt_truncation(self):
is_eval=False,
)
- assert result == \
- {
- "input_ids": [6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030],
- "labels": [-100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0],
- }
+ assert result == {
+ "input_ids": [6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030],
+ "labels": [-100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0],
+ }
def test_tokenize_row_completion_truncation(self):
# Define the input features
@@ -152,11 +149,10 @@ def test_tokenize_row_completion_truncation(self):
is_eval=False,
)
- assert result == \
- {
- "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11],
- "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100],
- }
+ assert result == {
+ "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11],
+ "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100],
+ }
def test_tokenize_row_prompt_completion_truncation(self):
# Define the input features
@@ -178,11 +174,10 @@ def test_tokenize_row_prompt_completion_truncation(self):
is_eval=False,
)
- assert result == \
- {
- "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030],
- "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1],
- }
+ assert result == {
+ "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030],
+ "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1],
+ }
def test_tokenize_row_multi_token_separator(self):
# Define the input features
@@ -204,11 +199,10 @@ def test_tokenize_row_multi_token_separator(self):
is_eval=False,
)
- assert result == \
- {
- "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 1030, 4995, 11, 22, 1030, 1030],
- "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, 0],
- }
+ assert result == {
+ "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 1030, 4995, 11, 22, 1030, 1030],
+ "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, 0],
+ }
class PRMTrainerTester(TrlTestCase):
diff --git a/tests/test_rewards.py b/tests/test_rewards.py
index aac6aabca0c..21827b6b4ea 100644
--- a/tests/test_rewards.py
+++ b/tests/test_rewards.py
@@ -85,7 +85,7 @@ def test_soft_overlong_punishment_intermediate_completion(self):
reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
completion_ids = [[1] * 90] # 90 is between 80 and 100
rewards = reward_fn(completion_ids)
- assert round(abs(rewards[0]--0.5), 4) == 0
+ assert round(abs(rewards[0] - -0.5), 4) == 0
if __name__ == "__main__":
diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py
index 2ddb51248ce..f79c4ca3f08 100644
--- a/tests/test_rloo_trainer.py
+++ b/tests/test_rloo_trainer.py
@@ -15,6 +15,7 @@
import unittest
from unittest.mock import patch
+import pytest
import torch
from datasets import load_dataset
from parameterized import parameterized
@@ -30,7 +31,6 @@
from trl import RLOOConfig, RLOOTrainer
from .testing_utils import TrlTestCase, require_vllm
-import pytest
if is_peft_available():
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 2a558d6093f..b60ea48bfaa 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -1504,24 +1504,26 @@ def test_peft_model_with_quantization(self):
lora_params_after.append(name)
# LoRA parameters should remain trainable
- assert len(trainable_params_after) > 0, \
- f"PeftModel should still have trainable parameters after SFTTrainer initialization. " \
- f"Found {len(trainable_params_after)} trainable params. " \
+ assert len(trainable_params_after) > 0, (
+ f"PeftModel should still have trainable parameters after SFTTrainer initialization. "
+ f"Found {len(trainable_params_after)} trainable params. "
f"This test fails without the fix for issue #3926."
+ )
- assert len(lora_params_after) > 0, \
- f"LoRA adapter parameters should remain trainable. " \
+ assert len(lora_params_after) > 0, (
+ f"LoRA adapter parameters should remain trainable. "
f"Found {len(lora_params_after)} trainable LoRA params out of {len(lora_params_before)} original."
+ )
# Ensure the parameter counts are preserved (no additional freezing occurred)
- assert len(trainable_params_before) == \
- len(trainable_params_after), \
+ assert len(trainable_params_before) == len(trainable_params_after), (
"Number of trainable parameters should not change after SFTTrainer initialization"
+ )
# Verify that all original LoRA parameters are still trainable
- assert set(lora_params_before) == \
- set(lora_params_after), \
+ assert set(lora_params_before) == set(lora_params_after), (
"All original LoRA parameters should remain trainable after SFTTrainer initialization"
+ )
@require_peft
def test_prompt_tuning_peft_model(self):
diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py
index 35a8da57cd3..b9a809ba071 100644
--- a/tests/test_trainers_args.py
+++ b/tests/test_trainers_args.py
@@ -84,8 +84,8 @@ def test_bco(self):
assert trainer.args.padding_value == -99
assert trainer.args.truncation_mode == "keep_start"
# self.assertEqual(trainer.args.generate_during_eval, True)
- assert trainer.args.is_encoder_decoder == True
- assert trainer.args.precompute_ref_log_probs == True
+ assert trainer.args.is_encoder_decoder
+ assert trainer.args.precompute_ref_log_probs
assert trainer.args.model_init_kwargs == {"trust_remote_code": True}
assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True}
assert trainer.args.dataset_num_proc == 4
@@ -123,14 +123,14 @@ def test_cpo(self):
assert trainer.args.beta == 0.5
assert trainer.args.label_smoothing == 0.5
assert trainer.args.loss_type == "hinge"
- assert trainer.args.disable_dropout == False
+ assert not trainer.args.disable_dropout
assert trainer.args.cpo_alpha == 0.5
assert trainer.args.simpo_gamma == 0.2
assert trainer.args.label_pad_token_id == -99
assert trainer.args.padding_value == -99
assert trainer.args.truncation_mode == "keep_start"
# self.assertEqual(trainer.args.generate_during_eval, True)
- assert trainer.args.is_encoder_decoder == True
+ assert trainer.args.is_encoder_decoder
assert trainer.args.model_init_kwargs == {"trust_remote_code": True}
assert trainer.args.dataset_num_proc == 4
@@ -183,16 +183,16 @@ def test_dpo(self):
assert trainer.args.max_length == 256
assert trainer.args.max_prompt_length == 64
assert trainer.args.max_completion_length == 64
- assert trainer.args.disable_dropout == False
+ assert not trainer.args.disable_dropout
# self.assertEqual(trainer.args.generate_during_eval, True)
- assert trainer.args.precompute_ref_log_probs == True
+ assert trainer.args.precompute_ref_log_probs
assert trainer.args.dataset_num_proc == 4
assert trainer.args.model_init_kwargs == {"trust_remote_code": True}
assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True}
assert trainer.args.model_adapter_name == "dummy_adapter"
assert trainer.args.ref_adapter_name == "dummy_adapter"
- assert trainer.args.reference_free == True
- assert trainer.args.force_use_ref_model == True
+ assert trainer.args.reference_free
+ assert trainer.args.force_use_ref_model
assert trainer.args.f_divergence_type == FDivergenceType.JS_DIVERGENCE
assert trainer.args.f_alpha_divergence_coef == 0.5
# self.assertEqual(trainer.args.sync_ref_model, True)
@@ -240,8 +240,8 @@ def test_kto(self):
assert trainer.args.padding_value == -99
assert trainer.args.truncation_mode == "keep_start"
# self.assertEqual(trainer.args.generate_during_eval, True)
- assert trainer.args.is_encoder_decoder == True
- assert trainer.args.precompute_ref_log_probs == True
+ assert trainer.args.is_encoder_decoder
+ assert trainer.args.precompute_ref_log_probs
assert trainer.args.model_init_kwargs == {"trust_remote_code": True}
assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True}
assert trainer.args.dataset_num_proc == 4
@@ -323,7 +323,7 @@ def test_orpo(self):
assert trainer.args.max_prompt_length == 64
assert trainer.args.max_completion_length == 64
assert trainer.args.beta == 0.5
- assert trainer.args.disable_dropout == False
+ assert not trainer.args.disable_dropout
assert trainer.args.label_pad_token_id == -99
def test_reward(self):
@@ -363,14 +363,14 @@ def test_sft(self):
)
trainer = SFTTrainer(model_id, args=training_args, train_dataset=dataset)
assert trainer.args.dataset_text_field == "dummy_text_field"
- assert trainer.args.packing == True
+ assert trainer.args.packing
assert trainer.args.max_length == 256
assert trainer.args.dataset_num_proc == 4
assert trainer.args.neftune_noise_alpha == 0.1
assert trainer.args.model_init_kwargs == {"trust_remote_code": True}
assert "append_concat_token" in trainer.args.dataset_kwargs
- assert trainer.args.dataset_kwargs["append_concat_token"] == True
- assert trainer.args.eval_packing == True
+ assert trainer.args.dataset_kwargs["append_concat_token"]
+ assert trainer.args.eval_packing
@parameterized.expand([(False,), (True,)])
def test_xpo(self, alpha_list):
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 2bfa9bff0c2..0dfa00d06e9 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -17,6 +17,7 @@
from unittest.mock import patch
import numpy as np
+import pytest
import torch
from datasets import load_dataset
from parameterized import parameterized
@@ -47,7 +48,6 @@
)
from .testing_utils import TrlTestCase, require_rich
-import pytest
if is_peft_available():
@@ -292,8 +292,9 @@ def test_data_collator_for_chatml(self):
# Verify labels are -100 for non-assistant parts
prompt_length = len(prompt_only)
- assert all(label == self.ignore_index for label in labels[:prompt_length]), \
+ assert all(label == self.ignore_index for label in labels[:prompt_length]), (
"Labels should be ignore_index for prompt tokens"
+ )
# Verify labels match assistant response after prompt
# Add a filter to remove any trailing tokens after the first <|im_end|>
@@ -309,18 +310,15 @@ def test_data_collator_for_chatml(self):
response_labels.append(label)
if label == self.tokenizer.convert_tokens_to_ids("<|im_end|>"):
break
- assert response_labels == \
- last_assistant_response_tokens, \
- "Labels should match assistant response tokens"
+ assert response_labels == last_assistant_response_tokens, "Labels should match assistant response tokens"
# Verify there isn't a generation prompt at the end
generation_prompt = "<|im_start|>assistant"
- assert not decoded_input.strip().endswith(generation_prompt), \
+ assert not decoded_input.strip().endswith(generation_prompt), (
f"Input should not end with generation prompt '{generation_prompt}'"
+ )
- assert response_labels == \
- last_assistant_response_tokens, \
- "Labels should match assistant response tokens"
+ assert response_labels == last_assistant_response_tokens, "Labels should match assistant response tokens"
class TestBatchGeneration(TrlTestCase):
@@ -397,7 +395,7 @@ def test_token_classification_task(self):
)
expected_accuracy = 0.5 # 2 matches, 2 mismatches
result = compute_accuracy(eval_pred)
- assert round(abs(result["accuracy"]-expected_accuracy), 7) == 0
+ assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0
def test_token_classification_task_with_ignored_tokens_0(self):
eval_pred = (
@@ -411,7 +409,7 @@ def test_token_classification_task_with_ignored_tokens_0(self):
)
expected_accuracy = 1.0 # All non-ignored tokens match
result = compute_accuracy(eval_pred)
- assert round(abs(result["accuracy"]-expected_accuracy), 7) == 0
+ assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0
def test_token_classification_task_with_ignored_tokens_1(self):
eval_pred = (
@@ -425,7 +423,7 @@ def test_token_classification_task_with_ignored_tokens_1(self):
)
expected_accuracy = 1 / 3 # 1 match, 2 mismatch, 1 ignored
result = compute_accuracy(eval_pred)
- assert round(abs(result["accuracy"]-expected_accuracy), 7) == 0
+ assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0
def test_rewards_comparison_task(self):
eval_pred = (
@@ -443,7 +441,7 @@ def test_rewards_comparison_task(self):
with self.assertLogs("trl.trainer.utils", level="WARNING") as cm:
result = compute_accuracy(eval_pred)
- assert round(abs(result["accuracy"]-expected_accuracy), 7) == 0
+ assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0
expected_warning = (
"There are 1 out of 3 instances where the predictions for both options are equal. "
"These instances are ignored in the accuracy computation."
diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py
index 836cd124f16..62d22583091 100644
--- a/tests/test_vllm_client_server.py
+++ b/tests/test_vllm_client_server.py
@@ -45,8 +45,10 @@ def test_single_element_list(self):
assert chunk_list([42], 2) == [[42], []]
def test_any_dtype(self):
- assert chunk_list([1, "two", 3.0, {"four": 4}, ["f", "i", "v", "e"]], 2) == \
- [[1, "two", 3.0], [{"four": 4}, ["f", "i", "v", "e"]]]
+ assert chunk_list([1, "two", 3.0, {"four": 4}, ["f", "i", "v", "e"]], 2) == [
+ [1, "two", 3.0],
+ [{"four": 4}, ["f", "i", "v", "e"]],
+ ]
@pytest.mark.slow
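The hunks above all follow one conversion recipe: unittest's assert* helpers become plain assert statements, and multi-line assertion messages are wrapped in parentheses instead of backslash continuations. A condensed sketch of the mapping, with illustrative values:

    # self.assertEqual(a, b)            ->  assert a == b
    # self.assertIsInstance(x, T)       ->  assert isinstance(x, T)
    # self.assertIsNone(x)              ->  assert x is None
    # self.assertTrue(flag)             ->  assert flag
    # self.assertAlmostEqual(a, b, 5)   ->  assert round(abs(a - b), 5) == 0

    value, expected = 3.14159, 3.14160  # made-up numbers for illustration

    # Long messages: parentheses keep the assert on one logical line
    assert round(abs(value - expected), 4) == 0, (
        f"value {value} differs from expected {expected} "
        f"beyond the allowed tolerance"
    )
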
From 7c463f6f94da1b780fcfc4cfc2bf79cf80433586 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral
<8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 1 Oct 2025 11:25:07 +0200
Subject: [PATCH 04/16] Fix pytest.raises with match arg
---
tests/test_cli_utils.py | 11 +++--------
tests/test_dpo_trainer.py | 23 +++++++++--------------
tests/test_grpo_trainer.py | 4 +---
tests/test_rloo_trainer.py | 4 +---
4 files changed, 14 insertions(+), 28 deletions(-)
diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py
index f8196dd5377..708640c55fd 100644
--- a/tests/test_cli_utils.py
+++ b/tests/test_cli_utils.py
@@ -45,9 +45,8 @@ def test_init_without_config_field(self):
def test_init_with_config_field(self):
"""Test initialization with a 'config' field in the dataclass (should raise ValueError)."""
- with pytest.raises(ValueError) as context:
+ with pytest.raises(ValueError, match="has a field named 'config'"):
TrlParser(dataclass_types=[InvalidDataclass])
- assert "has a field named 'config'" in str(context.exception)
@patch("builtins.open", mock_open(read_data="env:\n VAR1: value1\n VAR2: value2\narg1: 2"))
@patch("yaml.safe_load")
@@ -105,11 +104,9 @@ def test_parse_args_and_config_with_invalid_env(self, mock_yaml_load):
args = ["--arg1", "2", "--arg2", "value", "--config", "config.yaml"]
- with pytest.raises(ValueError) as context:
+ with pytest.raises(ValueError, match="`env` field should be a dict in the YAML file."):
parser.parse_args_and_config(args)
- assert str(context.exception) == "`env` field should be a dict in the YAML file."
-
def test_parse_args_and_config_without_config(self):
"""Test parse_args_and_config without the `--config` argument."""
parser = TrlParser(dataclass_types=[MyDataclass])
@@ -352,11 +349,9 @@ def test_dataset_mixture_with_test_split(self):
def test_empty_dataset_mixture_raises_error(self):
mixture_config = DatasetMixtureConfig(datasets=[])
- with pytest.raises(ValueError) as context:
+ with pytest.raises(ValueError, match="No datasets were loaded"):
get_dataset(mixture_config)
- assert "No datasets were loaded" in str(context.exception)
-
def test_mixture_multiple_different_configs(self):
dataset_config1 = DatasetConfig(
path="trl-internal-testing/zen", name="conversational_preference", split="train", columns=["prompt"]
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index 8c0a4efb250..0e4afd4a1bd 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -336,13 +336,12 @@ def test_train_with_multiple_loss_types(self):
assert "nll_loss" in metrics # SFT loss should be computed
def test_wrong_loss_weights_length(self):
- with pytest.raises(ValueError) as context:
+ with pytest.raises(ValueError, match="Length of loss_weights list"):
DPOConfig(
output_dir=self.tmp_dir,
loss_type=["sigmoid", "bco_pair"],
loss_weights=[1.0, 0.5, 0.1], # Wrong length
)
- assert "Length of loss_weights list" in str(context.exception)
@parameterized.expand([(None,), (0.5,)])
def test_dpo_trainer_without_providing_ref_model(self, rpo_alpha):
@@ -954,7 +953,10 @@ def test_dpo_trainer_dtype(self):
report_to="none",
)
- with pytest.raises(ValueError) as context:
+ with pytest.raises(
+ ValueError,
+ match="Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid `torch.dtype` (e.g., 'float32'), but got -1.",
+ ):
_ = DPOTrainer(
model=self.model_id,
processing_class=self.tokenizer,
@@ -962,11 +964,6 @@ def test_dpo_trainer_dtype(self):
train_dataset=dummy_dataset["train"],
)
- assert (
- "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid "
- "`torch.dtype` (e.g., 'float32'), but got -1." in str(context.exception)
- )
-
training_args = DPOConfig(
output_dir=self.tmp_dir,
per_device_train_batch_size=2,
@@ -975,7 +972,10 @@ def test_dpo_trainer_dtype(self):
report_to="none",
)
- with pytest.raises(ValueError) as context:
+ with pytest.raises(
+ ValueError,
+ match="Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid `torch.dtype` (e.g., 'float32'), but got -1.",
+ ):
_ = DPOTrainer(
model=self.model_id,
ref_model=self.model_id,
@@ -984,11 +984,6 @@ def test_dpo_trainer_dtype(self):
train_dataset=dummy_dataset["train"],
)
- assert (
- "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid "
- "`torch.dtype` (e.g., 'float32'), but got -1." in str(context.exception)
- )
-
def test_dpo_loss_alpha_div_f(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
tokenizer = AutoTokenizer.from_pretrained(model_id)
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index f353df444cf..785d8426f68 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -1645,7 +1645,7 @@ def test_mismatched_reward_processing_classes_length(self):
training_args = GRPOConfig(output_dir=self.tmp_dir, report_to="none")
- with pytest.raises(ValueError) as context:
+ with pytest.raises(ValueError, match="must match"):
GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs=reward_models,
@@ -1654,8 +1654,6 @@ def test_mismatched_reward_processing_classes_length(self):
train_dataset=dataset,
)
- assert "must match" in str(context.exception)
-
def test_correct_reward_processing_classes_list(self):
"""Test that correct list of reward_processing_classes works properly."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py
index f79c4ca3f08..23f709d466a 100644
--- a/tests/test_rloo_trainer.py
+++ b/tests/test_rloo_trainer.py
@@ -1353,7 +1353,7 @@ def test_mismatched_reward_processing_classes_length(self):
training_args = RLOOConfig(output_dir=self.tmp_dir, report_to="none")
- with pytest.raises(ValueError) as context:
+ with pytest.raises(ValueError, match="must match"):
RLOOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs=reward_models,
@@ -1362,8 +1362,6 @@ def test_mismatched_reward_processing_classes_length(self):
train_dataset=dataset,
)
- assert "must match" in str(context.exception)
-
def test_correct_reward_processing_classes_list(self):
"""Test that correct list of reward_processing_classes works properly."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
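The pattern removed here, capturing the exception and asserting on str(context.exception), is folded into pytest.raises itself: the match argument is applied to the exception message with re.search. A minimal sketch, assuming a hypothetical validator (not part of TRL):

    import pytest

    def set_learning_rate(lr: float) -> float:
        # Hypothetical helper used only for illustration.
        if lr <= 0:
            raise ValueError(f"learning rate must be positive, got {lr}")
        return lr

    def test_rejects_non_positive_lr():
        # `match` is a regex applied to str(exc) via re.search,
        # so a substring of the expected message is enough.
        with pytest.raises(ValueError, match="must be positive"):
            set_learning_rate(-1.0)
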
From 82107e4e17ad61cb664e75f4f3d5502ce5aa9fdf Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral
<8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 1 Oct 2025 12:43:17 +0200
Subject: [PATCH 05/16] Use re.escape for regex special characters
---
tests/test_dpo_trainer.py | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index 0e4afd4a1bd..58d04930b80 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
import sys
import unittest
from unittest.mock import MagicMock
@@ -955,7 +956,9 @@ def test_dpo_trainer_dtype(self):
with pytest.raises(
ValueError,
- match="Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid `torch.dtype` (e.g., 'float32'), but got -1.",
+ match=re.escape(
+ "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid `torch.dtype` (e.g., 'float32'), but got -1."
+ ),
):
_ = DPOTrainer(
model=self.model_id,
@@ -974,7 +977,9 @@ def test_dpo_trainer_dtype(self):
with pytest.raises(
ValueError,
- match="Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid `torch.dtype` (e.g., 'float32'), but got -1.",
+ match=re.escape(
+ "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid `torch.dtype` (e.g., 'float32'), but got -1."
+ ),
):
_ = DPOTrainer(
model=self.model_id,
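Because match is interpreted as a regular expression, literal characters such as parentheses and dots in the expected message must be escaped, which is what re.escape provides here. A small sketch of the idea, using a made-up validator and error message:

    import re
    import pytest

    def parse_dtype(value):
        # Hypothetical validator used only for illustration.
        raise ValueError(
            f"Expected 'auto' or a valid `torch.dtype` (e.g., 'float32'), but got {value}."
        )

    def test_invalid_dtype_message():
        message = "Expected 'auto' or a valid `torch.dtype` (e.g., 'float32'), but got -1."
        # Without re.escape, "(e.g., 'float32')" is parsed as a regex group and
        # the dots act as wildcards, so the check is looser than intended.
        with pytest.raises(ValueError, match=re.escape(message)):
            parse_dtype(-1)
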
From d93dd49826aae03c6210beea898d49da069c9407 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral
<8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 1 Oct 2025 13:03:52 +0200
Subject: [PATCH 06/16] Remove unittest from TrlTestCase
---
tests/testing_utils.py | 19 +++++--------------
1 file changed, 5 insertions(+), 14 deletions(-)
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
index 85026a53947..61a7935c544 100644
--- a/tests/testing_utils.py
+++ b/tests/testing_utils.py
@@ -14,13 +14,12 @@
import functools
import random
-import shutil
import signal
-import tempfile
import unittest
import warnings
import psutil
+import pytest
import torch
from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available
from transformers.testing_utils import torch_device
@@ -119,18 +118,10 @@ def judge(self, prompts, completions, shuffle_order=True, return_scores=False):
return [random.random() for _ in range(len(prompts))]
-class TrlTestCase(unittest.TestCase):
- """
- Base test case for TRL tests. Sets up a temporary directory for testing.
- """
-
- def setUp(self):
- super().setUp()
- self.tmp_dir = tempfile.mkdtemp()
-
- def tearDown(self):
- shutil.rmtree(self.tmp_dir)
- super().tearDown()
+class TrlTestCase:
+ @pytest.fixture(autouse=True)
+ def set_tmp_dir(self, tmp_path):
+ self.tmp_dir = str(tmp_path)
def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> callable:
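With the autouse fixture above, every test in a TrlTestCase subclass gets a fresh directory from pytest's built-in tmp_path fixture, and cleanup is handled by pytest rather than an explicit tearDown. A minimal sketch of the same pattern in isolation (class names are illustrative):

    import os
    import pytest

    class ExampleTestCase:
        # Illustrative stand-in for the TrlTestCase pattern above.
        @pytest.fixture(autouse=True)
        def set_tmp_dir(self, tmp_path):
            # tmp_path is a pathlib.Path unique to each test invocation;
            # pytest prunes old temporary directories automatically.
            self.tmp_dir = str(tmp_path)

    class TestExample(ExampleTestCase):
        def test_writes_into_tmp_dir(self):
            path = os.path.join(self.tmp_dir, "checkpoint.txt")
            with open(path, "w") as f:
                f.write("ok")
            assert os.path.isfile(path)
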
From 56adebe96a9c33e8c1cf47d3e2c63a3c860101ce Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral
<8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 1 Oct 2025 13:41:26 +0200
Subject: [PATCH 07/16] Replace setUp/tearDown
---
tests/slow/test_dpo_slow.py | 6 ++--
tests/slow/test_grpo_slow.py | 6 ++--
tests/slow/test_sft_slow.py | 6 ++--
tests/test_callbacks.py | 12 +++-----
tests/test_collators.py | 3 +-
tests/test_core.py | 3 +-
tests/test_cpo_trainer.py | 3 +-
tests/test_dataset_formatting.py | 6 ++--
tests/test_dpo_trainer.py | 6 ++--
tests/test_gkd_trainer.py | 8 ++---
tests/test_grpo_trainer.py | 8 ++---
tests/test_kto_trainer.py | 3 +-
...test_modeling_geometric_mixture_wrapper.py | 3 +-
tests/test_modeling_value_head.py | 12 +++-----
tests/test_nash_md_trainer.py | 3 +-
tests/test_online_dpo_trainer.py | 3 +-
tests/test_orpo_trainer.py | 3 +-
tests/test_peft_models.py | 3 +-
tests/test_ppo_trainer.py | 3 +-
tests/test_prm_trainer.py | 6 ++--
tests/test_rich_progress_callback.py | 3 +-
tests/test_utils.py | 9 ++----
tests/test_vllm_client_server.py | 30 +++++++------------
tests/test_xpo_trainer.py | 3 +-
24 files changed, 52 insertions(+), 99 deletions(-)
diff --git a/tests/slow/test_dpo_slow.py b/tests/slow/test_dpo_slow.py
index e24362fbc88..30f7ffacb18 100644
--- a/tests/slow/test_dpo_slow.py
+++ b/tests/slow/test_dpo_slow.py
@@ -38,8 +38,7 @@
@require_torch_accelerator
@require_peft
class DPOTrainerSlowTester(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
self.peft_config = LoraConfig(
lora_alpha=16,
@@ -50,11 +49,10 @@ def setUp(self):
)
self.max_length = 128
- def tearDown(self):
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
gc.collect()
- super().tearDown()
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS)))
def test_dpo_bare_model(self, model_id, loss_type, pre_compute_logits):
diff --git a/tests/slow/test_grpo_slow.py b/tests/slow/test_grpo_slow.py
index 3a453714daa..5c1164fa420 100644
--- a/tests/slow/test_grpo_slow.py
+++ b/tests/slow/test_grpo_slow.py
@@ -55,17 +55,15 @@
@pytest.mark.slow
@require_torch_accelerator
class GRPOTrainerSlowTester(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.train_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
self.eval_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="test")
self.max_length = 128
- def tearDown(self):
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
gc.collect()
- super().tearDown()
@parameterized.expand(MODELS_TO_TEST)
@require_liger_kernel
diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py
index 7e673a9457c..d0811f5db40 100755
--- a/tests/slow/test_sft_slow.py
+++ b/tests/slow/test_sft_slow.py
@@ -45,8 +45,7 @@
@require_torch_accelerator
@require_peft
class SFTTrainerSlowTester(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]")
self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]")
self.max_length = 128
@@ -58,11 +57,10 @@ def setUp(self):
task_type="CAUSAL_LM",
)
- def tearDown(self):
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
gc.collect()
- super().tearDown()
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
def test_sft_trainer_str(self, model_name, packing):
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index 501641b5df9..448e11bed8f 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -67,8 +67,7 @@ def __init__(self, model, ref_model, args, train_dataset, eval_dataset, processi
class WinRateCallbackTester(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
self.ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
@@ -226,8 +225,7 @@ def test_lora(self):
class LogCompletionsCallbackTester(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -321,8 +319,7 @@ def test_basic_comet(self):
@require_mergekit
class MergeModelCallbackTester(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
@@ -378,8 +375,7 @@ def test_every_checkpoint(self):
class BEMACallbackTester(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
self.tokenizer.pad_token = self.tokenizer.eos_token
diff --git a/tests/test_collators.py b/tests/test_collators.py
index d798f29d54d..3159184f558 100644
--- a/tests/test_collators.py
+++ b/tests/test_collators.py
@@ -21,8 +21,7 @@
class TestDataCollatorForPreference(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.collator = DataCollatorForPreference(pad_token_id=0)
def assertTensorEqual(self, tensor1, tensor2):
diff --git a/tests/test_core.py b/tests/test_core.py
index 78e57c64b30..45449e5815b 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -25,8 +25,7 @@ class CoreTester(TrlTestCase):
A wrapper class for testing core utils functions
"""
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.test_input = torch.Tensor([1, 2, 3, 4])
self.test_mask = torch.Tensor([0, 1, 1, 0])
self.test_input_unmasked = self.test_input[1:3]
diff --git a/tests/test_cpo_trainer.py b/tests/test_cpo_trainer.py
index c0edd771c5e..4f36be3bc84 100644
--- a/tests/test_cpo_trainer.py
+++ b/tests/test_cpo_trainer.py
@@ -26,8 +26,7 @@
class CPOTrainerTester(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py
index 78f2e7dfb1d..49b9b460a86 100644
--- a/tests/test_dataset_formatting.py
+++ b/tests/test_dataset_formatting.py
@@ -24,8 +24,7 @@
class DatasetFormattingTestCase(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.llama_tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-MistralForCausalLM-0.1")
self.chatml_tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
@@ -118,8 +117,7 @@ def test_get_formatting_func_from_dataset_with_unknown_format(self):
class SetupChatFormatTestCase(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
# remove built-in chat_template to simulate a model having no chat_template
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index 58d04930b80..db0f96e9fb2 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -49,8 +49,7 @@
class TestTokenizeRow(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
# Set up the mock tokenizer with specific behaviors
self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
self.tokenizer.bos_token_id = 0
@@ -154,8 +153,7 @@ def test_tokenize_row_with_truncation_and_special_tokens(self):
class DPOTrainerTester(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id)
diff --git a/tests/test_gkd_trainer.py b/tests/test_gkd_trainer.py
index 7bdec1d4553..a1b62cd993e 100644
--- a/tests/test_gkd_trainer.py
+++ b/tests/test_gkd_trainer.py
@@ -29,7 +29,7 @@
class TestGKDTrainer(TrlTestCase):
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
cls.tokenizer = AutoTokenizer.from_pretrained(model_id)
cls.tokenizer.pad_token = cls.tokenizer.eos_token
@@ -124,8 +124,7 @@ def test_generate_on_policy_outputs(self):
class TestGeneralizedJSDLoss(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.batch_size = 2
self.seq_length = 3
self.vocab_size = 5
@@ -200,8 +199,7 @@ def test_zero_loss_for_identical_inputs(self):
class GKDTrainerTester(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
self.teacher_model = AutoModelForCausalLM.from_pretrained(self.model_id)
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index 785d8426f68..30a99199ee9 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -1712,8 +1712,8 @@ def test_single_reward_model_with_single_processing_class(self):
@pytest.mark.low_priority
-class TestReplayBuffer(unittest.TestCase):
- def setUp(self):
+class TestReplayBuffer:
+ def setup_method(self):
self.replay_buffer = ReplayBuffer(max_size=5)
def test_add(self):
@@ -1780,8 +1780,8 @@ def test_sample(self):
@pytest.mark.low_priority
-class TestUpdateWithReplayBuffer(unittest.TestCase):
- def setUp(self):
+class TestUpdateWithReplayBuffer:
+ def setup_method(self):
config = GRPOWithReplayBufferConfig(
replay_buffer_size=5,
)
diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py
index da749abc62d..6e303ebbe83 100644
--- a/tests/test_kto_trainer.py
+++ b/tests/test_kto_trainer.py
@@ -27,8 +27,7 @@
class KTOTrainerTester(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id)
diff --git a/tests/test_modeling_geometric_mixture_wrapper.py b/tests/test_modeling_geometric_mixture_wrapper.py
index 65553b79b77..7dcd89f757e 100644
--- a/tests/test_modeling_geometric_mixture_wrapper.py
+++ b/tests/test_modeling_geometric_mixture_wrapper.py
@@ -22,8 +22,7 @@
class TestGeometricMixtureWrapper(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py
index 0aa56901f82..a7bfd08eeb9 100644
--- a/tests/test_modeling_value_head.py
+++ b/tests/test_modeling_value_head.py
@@ -55,8 +55,7 @@ class VHeadModelTester(TrlTestCase):
trl_model_class = None
transformers_model_class = None
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
def test_value_head(self):
@@ -189,10 +188,9 @@ class CausalLMValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase):
trl_model_class = AutoModelForCausalLMWithValueHead
transformers_model_class = AutoModelForCausalLM
- def tearDown(self):
+ def teardown_method(self):
# free memory
gc.collect()
- super().tearDown()
def test_inference(self):
r"""
@@ -303,10 +301,9 @@ class Seq2SeqValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase):
trl_model_class = AutoModelForSeq2SeqLMWithValueHead
transformers_model_class = AutoModelForSeq2SeqLM
- def tearDown(self):
+ def teardown_method(self):
# free memory
gc.collect()
- super().tearDown()
def test_inference(self):
r"""
@@ -409,8 +406,7 @@ def test_transformers_bf16_kwargs(self):
class ReferenceModelTest(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.model = AutoModelForCausalLMWithValueHead.from_pretrained("trl-internal-testing/tiny-GPT2LMHeadModel")
self.test_input = torch.tensor([[0, 1, 2, 3]])
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1)
diff --git a/tests/test_nash_md_trainer.py b/tests/test_nash_md_trainer.py
index 25caaff3d38..90df42b48a7 100644
--- a/tests/test_nash_md_trainer.py
+++ b/tests/test_nash_md_trainer.py
@@ -29,8 +29,7 @@
class TestNashMDTrainer(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id)
diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py
index 54e00447593..4552edafaef 100644
--- a/tests/test_online_dpo_trainer.py
+++ b/tests/test_online_dpo_trainer.py
@@ -36,8 +36,7 @@
class TestOnlineDPOTrainer(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id)
diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py
index 2f444eb4dfc..8ab8819675a 100644
--- a/tests/test_orpo_trainer.py
+++ b/tests/test_orpo_trainer.py
@@ -26,8 +26,7 @@
class ORPOTrainerTester(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py
index ac62174db00..b8bacb7cce2 100644
--- a/tests/test_peft_models.py
+++ b/tests/test_peft_models.py
@@ -33,8 +33,7 @@
@require_peft
class PeftModelTester(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.causal_lm_model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.lora_config = LoraConfig(
r=16,
diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py
index 13354c36b29..c4c69d24e10 100644
--- a/tests/test_ppo_trainer.py
+++ b/tests/test_ppo_trainer.py
@@ -30,8 +30,7 @@
class TestPPOTrainer(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
# Set up the models and tokenizer using the test model
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
diff --git a/tests/test_prm_trainer.py b/tests/test_prm_trainer.py
index 2a8083d7492..f435c1295ab 100644
--- a/tests/test_prm_trainer.py
+++ b/tests/test_prm_trainer.py
@@ -31,8 +31,7 @@
class TestTokenizeRow(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
# Set up the mock tokenizer with specific behaviors
self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
self.tokenizer.bos_token_id = 0
@@ -206,8 +205,7 @@ def test_tokenize_row_multi_token_separator(self):
class PRMTrainerTester(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForTokenClassification.from_pretrained(model_id)
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
diff --git a/tests/test_rich_progress_callback.py b/tests/test_rich_progress_callback.py
index d9069481263..d246b694b72 100644
--- a/tests/test_rich_progress_callback.py
+++ b/tests/test_rich_progress_callback.py
@@ -34,8 +34,7 @@ def forward(self, x):
@require_rich
class TestRichProgressCallback(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.dummy_model = DummyModel()
self.dummy_train_dataset = Dataset.from_list([{"x": 1.0, "y": 2.0}] * 5)
self.dummy_val_dataset = Dataset.from_list([{"x": 1.0, "y": 2.0}] * 101)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 0dfa00d06e9..d663328e244 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -173,8 +173,7 @@ def test_create_peft_config_use_peft_true(self):
class TestDecodeAndStripPadding(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
def test_example_with_padding(self):
@@ -235,8 +234,7 @@ def test_val_none(self):
class TestDataCollatorForChatML(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
# Initialize the tokenizer
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
if self.tokenizer.pad_token is None:
@@ -322,8 +320,7 @@ def test_data_collator_for_chatml(self):
class TestBatchGeneration(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
# Initialize the tokenizer
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.device = "cuda" if torch.cuda.is_available() else "cpu"
diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py
index 62d22583091..23c2080289c 100644
--- a/tests/test_vllm_client_server.py
+++ b/tests/test_vllm_client_server.py
@@ -57,7 +57,7 @@ class TestVLLMClientServer(TrlTestCase):
model_id = "Qwen/Qwen2.5-1.5B"
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
# We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1"
env = os.environ.copy()
VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
@@ -115,9 +115,7 @@ def test_reset_prefix_cache(self):
self.client.reset_prefix_cache()
@classmethod
- def tearDownClass(cls):
- super().tearDownClass()
-
+ def teardown_class(cls):
# Close the client
cls.client.close_communicator()
@@ -133,7 +131,7 @@ class TestVLLMClientServerBaseURL(TrlTestCase):
model_id = "Qwen/Qwen2.5-1.5B"
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
# We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1"
env = os.environ.copy()
VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
@@ -191,9 +189,7 @@ def test_reset_prefix_cache(self):
self.client.reset_prefix_cache()
@classmethod
- def tearDownClass(cls):
- super().tearDownClass()
-
+ def teardown_class(cls):
# Close the client
cls.client.close_communicator()
@@ -208,7 +204,7 @@ class TestVLLMClientServerTP(TrlTestCase):
model_id = "Qwen/Qwen2.5-1.5B"
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
# We want the server to run on accelerator 1 and 2, so we set VISIBLE_DEVICES to "1,2"
env = os.environ.copy()
VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
@@ -249,9 +245,7 @@ def test_reset_prefix_cache(self):
self.client.reset_prefix_cache()
@classmethod
- def tearDownClass(cls):
- super().tearDownClass()
-
+ def teardown_class(cls):
# Close the client
cls.client.close_communicator()
@@ -266,7 +260,7 @@ class TestVLLMClientServerDP(TrlTestCase):
model_id = "Qwen/Qwen2.5-1.5B"
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
# We want the server to run on accelerator 1 and 2, so we set VISIBLE_DEVICES to "1,2"
env = os.environ.copy()
VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
@@ -307,9 +301,7 @@ def test_reset_prefix_cache(self):
self.client.reset_prefix_cache()
@classmethod
- def tearDownClass(cls):
- super().tearDownClass()
-
+ def teardown_class(cls):
# Close the client
cls.client.close_communicator()
@@ -326,7 +318,7 @@ class TestVLLMClientServerDeviceParameter(TrlTestCase):
model_id = "Qwen/Qwen2.5-1.5B"
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
# We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1"
env = os.environ.copy()
VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
@@ -380,9 +372,7 @@ def test_init_communicator_with_torch_device(self):
client.close_communicator()
@classmethod
- def tearDownClass(cls):
- super().tearDownClass()
-
+ def teardown_class(cls):
# vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
# kill the server process and its children explicitly.
kill_process(cls.server_process)
diff --git a/tests/test_xpo_trainer.py b/tests/test_xpo_trainer.py
index 7a69455df33..bbf51bf83a8 100644
--- a/tests/test_xpo_trainer.py
+++ b/tests/test_xpo_trainer.py
@@ -29,8 +29,7 @@
class TestXPOTrainer(TrlTestCase):
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id)
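pytest discovers these xunit-style hooks by name: setup_method/teardown_method run around every test, while setup_class/teardown_class run once per class, so no unittest.TestCase base class or super() chaining is needed. A minimal sketch of the lifecycle:

    class TestLifecycle:
        @classmethod
        def setup_class(cls):
            # Runs once, before the first test in the class.
            cls.shared_resource = []

        def setup_method(self):
            # Runs before each test method.
            self.calls = 0

        def teardown_method(self):
            # Runs after each test method, even if it failed.
            self.calls = None

        @classmethod
        def teardown_class(cls):
            # Runs once, after the last test in the class.
            cls.shared_resource.clear()

        def test_counts_a_call(self):
            self.calls += 1
            assert self.calls == 1
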
From 8c30eec2f7c6878b7a7eed26f8f20d0b31863722 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral
<8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 1 Oct 2025 13:53:44 +0200
Subject: [PATCH 08/16] Replace unittest.skipUnless
---
tests/testing_utils.py | 81 ++++++++----------------------------------
1 file changed, 14 insertions(+), 67 deletions(-)
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
index 61a7935c544..fdc91ccab81 100644
--- a/tests/testing_utils.py
+++ b/tests/testing_utils.py
@@ -15,7 +15,6 @@
import functools
import random
import signal
-import unittest
import warnings
import psutil
@@ -29,72 +28,20 @@
from trl.import_utils import is_joblib_available, is_llm_blender_available, is_mergekit_available, is_vllm_available
-# transformers.testing_utils contains a require_bitsandbytes function, but relies on pytest markers which we don't use
-# in our test suite. We therefore need to implement our own version of this function.
-def require_bitsandbytes(test_case):
- """
- Decorator marking a test that requires bitsandbytes. Skips the test if bitsandbytes is not available.
- """
- return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)
-
-
-def require_comet(test_case):
- """
- Decorator marking a test that requires Comet. Skips the test if Comet is not available.
- """
- return unittest.skipUnless(is_comet_available(), "test requires comet_ml")(test_case)
-
-
-def require_llm_blender(test_case):
- """
- Decorator marking a test that requires llm-blender. Skips the test if llm-blender is not available.
- """
- return unittest.skipUnless(is_llm_blender_available(), "test requires llm-blender")(test_case)
-
-
-def require_mergekit(test_case):
- """
- Decorator marking a test that requires mergekit. Skips the test if mergekit is not available.
- """
- return unittest.skipUnless(is_mergekit_available(), "test requires mergekit")(test_case)
-
-
-def require_rich(test_case):
- """
- Decorator marking a test that requires rich. Skips the test if rich is not available.
- """
- return unittest.skipUnless(is_rich_available(), "test requires rich")(test_case)
-
-
-def require_sklearn(test_case):
- """
- Decorator marking a test that requires sklearn. Skips the test if sklearn is not available.
- """
- return unittest.skipUnless(is_sklearn_available() and is_joblib_available(), "test requires sklearn")(test_case)
-
-
-def require_vllm(test_case):
- """
- Decorator marking a test that requires vllm. Skips the test if vllm is not available.
- """
- return unittest.skipUnless(is_vllm_available(), "test requires vllm")(test_case)
-
-
-def require_no_wandb(test_case):
- """
- Decorator marking a test that requires no wandb. Skips the test if wandb is available.
- """
- return unittest.skipUnless(not is_wandb_available(), "test requires no wandb")(test_case)
-
-
-def require_3_accelerators(test_case):
- """
- Decorator marking a test that requires at least 3 accelerators. Skips the test if 3 accelerators are not available.
- """
- torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
- return unittest.skipUnless(
- torch_accelerator_module.device_count() >= 3, f"test requires at least 3 {torch_device}s"
- )(test_case)
+require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")
+require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml")
+require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender")
+require_mergekit = pytest.mark.skipif(not is_mergekit_available(), reason="test requires mergekit")
+require_rich = pytest.mark.skipif(not is_rich_available(), reason="test requires rich")
+require_sklearn = pytest.mark.skipif(
+ not (is_sklearn_available() and is_joblib_available()), reason="test requires sklearn"
+)
+require_vllm = pytest.mark.skipif(not is_vllm_available(), reason="test requires vllm")
+require_no_wandb = pytest.mark.skipif(is_wandb_available(), reason="test requires no wandb")
+require_3_accelerators = pytest.mark.skipif(
+ not (getattr(torch, torch_device, torch.cuda).device_count() >= 3),
+ reason=f"test requires at least 3 {torch_device}s",
+)
class RandomBinaryJudge(BaseBinaryJudge):
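A usage sketch for the module-level markers defined above (hypothetical test function, not part of the patch): a pytest.mark.skipif object decorates a test function or a whole class exactly like the old unittest decorators did.

from .testing_utils import require_sklearn


@require_sklearn
def test_accuracy_is_perfect_on_identical_labels():
    # skipped automatically when sklearn or joblib is not installed
    from sklearn.metrics import accuracy_score

    assert accuracy_score([0, 1, 1], [0, 1, 1]) == 1.0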
From 4630ed4ae57645742c023fa53cd102426dadc1db Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral
<8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 1 Oct 2025 14:03:25 +0200
Subject: [PATCH 09/16] Remove remaining TestCase
---
tests/test_cli_utils.py | 3 +--
tests/test_data_utils.py | 2 +-
tests/test_rewards.py | 2 +-
3 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py
index 708640c55fd..a7cda2fddf8 100644
--- a/tests/test_cli_utils.py
+++ b/tests/test_cli_utils.py
@@ -13,7 +13,6 @@
# limitations under the License.
import tempfile
-import unittest
from dataclasses import dataclass
from unittest.mock import mock_open, patch
@@ -268,7 +267,7 @@ def test_subparsers_multiple_with_config_defaults(self, mock_yaml_load):
assert result_args[0].arg2 == "config_value" # Default from config
-class TestGetDataset(unittest.TestCase):
+class TestGetDataset:
def test_single_dataset_with_config(self):
mixture_config = DatasetMixtureConfig(
datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_language_modeling")]
diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py
index abb27258e3f..dcb10a59537 100644
--- a/tests/test_data_utils.py
+++ b/tests/test_data_utils.py
@@ -40,7 +40,7 @@
from .testing_utils import TrlTestCase
-class PrepareMultimodalMessagesTester(unittest.TestCase):
+class TestPrepareMultimodalMessages:
def test_basic_user_assistant_conversation(self):
"""Test basic conversation with user and assistant messages."""
messages = [
diff --git a/tests/test_rewards.py b/tests/test_rewards.py
index 21827b6b4ea..46fddb5b19d 100644
--- a/tests/test_rewards.py
+++ b/tests/test_rewards.py
@@ -63,7 +63,7 @@ def test_mixed_format(self):
assert rewards == expected_rewards
-class SoftOverlongPunishmentRewardTester(unittest.TestCase):
+class TestSoftOverlongPunishmentReward:
def test_soft_overlong_punishment_short_completion(self):
"""Test soft overlong punishment reward function with a short completion."""
# length 50, with max=100 and soft cache=20, reward should be 0.
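With unittest.TestCase gone, these are plain classes that pytest collects by name and whose checks are written as bare assert statements; a minimal sketch of the resulting style (hypothetical example, not taken from the test suite):

class TestClamp:
    def test_value_inside_range_is_unchanged(self):
        assert max(0.0, min(1.0, 0.5)) == 0.5

    def test_value_above_range_is_clipped(self):
        assert max(0.0, min(1.0, 1.7)) == 1.0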
From d9e8bf51e87365d01f862699d1b987529860b727 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral
<8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 1 Oct 2025 14:21:39 +0200
Subject: [PATCH 10/16] Fix naming of tests
---
tests/slow/test_dpo_slow.py | 2 +-
tests/slow/test_grpo_slow.py | 2 +-
tests/slow/test_sft_slow.py | 2 +-
tests/test_bco_trainer.py | 2 +-
tests/test_best_of_n_sampler.py | 2 +-
tests/test_callbacks.py | 8 ++++----
tests/test_core.py | 2 +-
tests/test_cpo_trainer.py | 2 +-
tests/test_data_utils.py | 12 ++++++------
tests/test_dataset_formatting.py | 6 +++---
tests/test_dpo_trainer.py | 4 ++--
tests/test_gkd_trainer.py | 4 ++--
tests/test_grpo_trainer.py | 6 +++---
tests/test_kto_trainer.py | 2 +-
tests/test_modeling_value_head.py | 8 ++++----
tests/test_online_dpo_trainer.py | 2 +-
tests/test_orpo_trainer.py | 2 +-
tests/test_peft_models.py | 2 +-
tests/test_prm_trainer.py | 2 +-
tests/test_reward_trainer.py | 2 +-
tests/test_rewards.py | 2 +-
tests/test_rloo_trainer.py | 2 +-
tests/test_sft_trainer.py | 4 ++--
tests/test_trainers_args.py | 2 +-
tests/test_utils.py | 12 ++++++------
25 files changed, 48 insertions(+), 48 deletions(-)
diff --git a/tests/slow/test_dpo_slow.py b/tests/slow/test_dpo_slow.py
index 30f7ffacb18..498c71179ac 100644
--- a/tests/slow/test_dpo_slow.py
+++ b/tests/slow/test_dpo_slow.py
@@ -37,7 +37,7 @@
@pytest.mark.slow
@require_torch_accelerator
@require_peft
-class DPOTrainerSlowTester(TrlTestCase):
+class TestDPOTrainerSlow(TrlTestCase):
def setup_method(self):
self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
self.peft_config = LoraConfig(
diff --git a/tests/slow/test_grpo_slow.py b/tests/slow/test_grpo_slow.py
index 5c1164fa420..1e594d0aedd 100644
--- a/tests/slow/test_grpo_slow.py
+++ b/tests/slow/test_grpo_slow.py
@@ -54,7 +54,7 @@
@pytest.mark.slow
@require_torch_accelerator
-class GRPOTrainerSlowTester(TrlTestCase):
+class TestGRPOTrainerSlow(TrlTestCase):
def setup_method(self):
self.train_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
self.eval_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="test")
diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py
index d0811f5db40..114a9069943 100755
--- a/tests/slow/test_sft_slow.py
+++ b/tests/slow/test_sft_slow.py
@@ -44,7 +44,7 @@
@pytest.mark.slow
@require_torch_accelerator
@require_peft
-class SFTTrainerSlowTester(TrlTestCase):
+class TestSFTTrainerSlow(TrlTestCase):
def setup_method(self):
self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]")
self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]")
diff --git a/tests/test_bco_trainer.py b/tests/test_bco_trainer.py
index 91f905eb628..07ce1178837 100644
--- a/tests/test_bco_trainer.py
+++ b/tests/test_bco_trainer.py
@@ -33,7 +33,7 @@
from peft import LoraConfig
-class BCOTrainerTester(TrlTestCase):
+class TestBCOTrainer(TrlTestCase):
@parameterized.expand(
[
("standard_preference",),
diff --git a/tests/test_best_of_n_sampler.py b/tests/test_best_of_n_sampler.py
index cf6810976ec..d52538c71d0 100644
--- a/tests/test_best_of_n_sampler.py
+++ b/tests/test_best_of_n_sampler.py
@@ -27,7 +27,7 @@ def queries_to_scores(list_of_strings):
return [torch.rand(1).item() for _ in list_of_strings]
-class BestOfNSamplerTester(TrlTestCase):
+class TestBestOfNSampler(TrlTestCase):
"""
Tests the BestOfNSampler class
"""
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index 448e11bed8f..c41d68b7efa 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -66,7 +66,7 @@ def __init__(self, model, ref_model, args, train_dataset, eval_dataset, processi
self.ref_model = ref_model
-class WinRateCallbackTester(TrlTestCase):
+class TestWinRateCallback(TrlTestCase):
def setup_method(self):
self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
self.ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
@@ -224,7 +224,7 @@ def test_lora(self):
assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)
-class LogCompletionsCallbackTester(TrlTestCase):
+class TestLogCompletionsCallback(TrlTestCase):
def setup_method(self):
self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
@@ -318,7 +318,7 @@ def test_basic_comet(self):
@require_mergekit
-class MergeModelCallbackTester(TrlTestCase):
+class TestMergeModelCallback(TrlTestCase):
def setup_method(self):
self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
@@ -374,7 +374,7 @@ def test_every_checkpoint(self):
assert os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}."
-class BEMACallbackTester(TrlTestCase):
+class TestBEMACallback(TrlTestCase):
def setup_method(self):
self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
diff --git a/tests/test_core.py b/tests/test_core.py
index 45449e5815b..85d99615be9 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -20,7 +20,7 @@
from .testing_utils import TrlTestCase
-class CoreTester(TrlTestCase):
+class TestCore(TrlTestCase):
"""
A wrapper class for testing core utils functions
"""
diff --git a/tests/test_cpo_trainer.py b/tests/test_cpo_trainer.py
index 4f36be3bc84..d2c926a6735 100644
--- a/tests/test_cpo_trainer.py
+++ b/tests/test_cpo_trainer.py
@@ -25,7 +25,7 @@
from .testing_utils import TrlTestCase
-class CPOTrainerTester(TrlTestCase):
+class TestCPOTrainer(TrlTestCase):
def setup_method(self):
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py
index dcb10a59537..324680e6331 100644
--- a/tests/test_data_utils.py
+++ b/tests/test_data_utils.py
@@ -147,7 +147,7 @@ def test_mixed_prepared_and_unprepared_messages(self):
assert messages == expected
-class IsConversationalTester(TrlTestCase):
+class TestIsConversational(TrlTestCase):
conversational_examples = [
{ # Language modeling
"messages": [
@@ -257,7 +257,7 @@ def test_non_conversational(self, example):
assert not is_conversational(example)
-class IsConversationalFromValueTester(TrlTestCase):
+class TestIsConversationalFromValue(TrlTestCase):
def test_positive_1(self):
example = {
"conversations": [
@@ -281,7 +281,7 @@ def test_negative_2(self):
assert not is_conversational_from_value(example)
-class ApplyChatTemplateTester(TrlTestCase):
+class TestApplyChatTemplate(TrlTestCase):
tokenizers = [
"trl-internal-testing/tiny-CohereForCausalLM",
"trl-internal-testing/tiny-DbrxForCausalLM",
@@ -429,7 +429,7 @@ def get_current_temperature(location: str):
assert "get_current_temperature" not in result_without_tools["prompt"]
-class ApplyChatTemplateHarmonyTester(TrlTestCase):
+class TestApplyChatTemplateHarmony(TrlTestCase):
def test_language_modeling(self):
messages = {
"messages": [
@@ -655,7 +655,7 @@ def test_unpaired_preference(self):
assert output["label"]
-class UnpairPreferenceDatasetTester(TrlTestCase):
+class TestUnpairPreferenceDataset(TrlTestCase):
paired_dataset = Dataset.from_dict(
{
"prompt": ["The sky is", "The sun is"],
@@ -717,7 +717,7 @@ def test_maybe_unpair_preference_dataset_dict_already_paired(self):
)
-class ExtractPromptTester(TrlTestCase):
+class TestExtractPrompt(TrlTestCase):
example_implicit_prompt_conversational = {
"chosen": [
{"role": "user", "content": "What color is the sky?"},
diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py
index 49b9b460a86..80f65f964de 100644
--- a/tests/test_dataset_formatting.py
+++ b/tests/test_dataset_formatting.py
@@ -23,7 +23,7 @@
from .testing_utils import TrlTestCase
-class DatasetFormattingTestCase(TrlTestCase):
+class TestDatasetFormatting(TrlTestCase):
def setup_method(self):
self.llama_tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-MistralForCausalLM-0.1")
self.chatml_tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
@@ -116,7 +116,7 @@ def test_get_formatting_func_from_dataset_with_unknown_format(self):
assert formatting_func is None
-class SetupChatFormatTestCase(TrlTestCase):
+class TestSetupChatFormat(TrlTestCase):
def setup_method(self):
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
@@ -156,7 +156,7 @@ def test_example_with_setup_model(self):
)
-class CloneChatTemplateTestCase(TrlTestCase):
+class TestCloneChatTemplate(TrlTestCase):
def test_clone(self):
# This tokenizer doesn't have a chat_template by default
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index db0f96e9fb2..c7e5ebac50b 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -152,7 +152,7 @@ def test_tokenize_row_with_truncation_and_special_tokens(self):
}
-class DPOTrainerTester(TrlTestCase):
+class TestDPOTrainer(TrlTestCase):
def setup_method(self):
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
@@ -1406,7 +1406,7 @@ def test_train_with_iterable_dataset(self):
@require_vision
-class DPOVisionTrainerTester(TrlTestCase):
+class TestDPOVisionTrainer(TrlTestCase):
@parameterized.expand(
[
# ("trl-internal-testing/tiny-Idefics2ForConditionalGeneration",), device issue from transformers, see https://github.com/huggingface/transformers/pull/39975
diff --git a/tests/test_gkd_trainer.py b/tests/test_gkd_trainer.py
index a1b62cd993e..b311ce2b0b6 100644
--- a/tests/test_gkd_trainer.py
+++ b/tests/test_gkd_trainer.py
@@ -27,7 +27,7 @@
from .testing_utils import TrlTestCase
-class TestGKDTrainer(TrlTestCase):
+class TestGKDTrainerGenerateOnPolicy(TrlTestCase):
@classmethod
def setup_class(cls):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
@@ -198,7 +198,7 @@ def test_zero_loss_for_identical_inputs(self):
assert round(abs(loss.item() - 0), 6) == 0
-class GKDTrainerTester(TrlTestCase):
+class TestGKDTrainer(TrlTestCase):
def setup_method(self):
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index 30a99199ee9..3b17b3554b7 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -43,7 +43,7 @@
from peft import LoraConfig, PeftModel
-class GetHighEntropyMaskTester(TrlTestCase):
+class TestGetHighEntropyMask(TrlTestCase):
def get_high_entropy_mask(self, entropies, mask, threshold):
"""Helper method to test the get_high_entropy_mask functionality."""
# Create a mock trainer with minimal setup
@@ -115,7 +115,7 @@ def test_compute_entropy_all_masked(self):
torch.testing.assert_close(entropy_mask, expected_mask)
-class GRPOTrainerTester(TrlTestCase):
+class TestGRPOTrainer(TrlTestCase):
def test_init_minimal(self):
# Test that GRPOTrainer can be instantiated with only model, reward_model and train_dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -1975,7 +1975,7 @@ def custom_reward_func(completions, **kwargs):
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
-class GSPOTokenTrainerTester(TrlTestCase):
+class TestGSPOTokenTrainer(TrlTestCase):
def test_training(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py
index 6e303ebbe83..7cf9466eafa 100644
--- a/tests/test_kto_trainer.py
+++ b/tests/test_kto_trainer.py
@@ -26,7 +26,7 @@
from .testing_utils import TrlTestCase, require_no_wandb
-class KTOTrainerTester(TrlTestCase):
+class TestKTOTrainer(TrlTestCase):
def setup_method(self):
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py
index a7bfd08eeb9..65716cde99f 100644
--- a/tests/test_modeling_value_head.py
+++ b/tests/test_modeling_value_head.py
@@ -50,7 +50,7 @@
class BaseTester:
- class VHeadModelTester(TrlTestCase):
+ class TestVHeadModel(TrlTestCase):
all_model_names = None
trl_model_class = None
transformers_model_class = None
@@ -179,7 +179,7 @@ def test_from_save_transformers(self):
)
-class CausalLMValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase):
+class TestCausalLMValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase):
"""
Testing suite for v-head models.
"""
@@ -292,7 +292,7 @@ def test_push_to_hub(self):
)
-class Seq2SeqValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase):
+class TestSeq2SeqValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase):
"""
Testing suite for v-head models.
"""
@@ -405,7 +405,7 @@ def test_transformers_bf16_kwargs(self):
_ = trl_model(input_ids=dummy_input, decoder_input_ids=dummy_input)
-class ReferenceModelTest(TrlTestCase):
+class TestReferenceModel(TrlTestCase):
def setup_method(self):
self.model = AutoModelForCausalLMWithValueHead.from_pretrained("trl-internal-testing/tiny-GPT2LMHeadModel")
self.test_input = torch.tensor([[0, 1, 2, 3]])
diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py
index 4552edafaef..981ed9dc47f 100644
--- a/tests/test_online_dpo_trainer.py
+++ b/tests/test_online_dpo_trainer.py
@@ -482,7 +482,7 @@ def simple_reward_func(prompts, completions, completion_ids, **kwargs):
@require_vision
-class OnlineDPOVisionTrainerTester(TrlTestCase):
+class TestOnlineDPOVisionTrainer(TrlTestCase):
@parameterized.expand(
[
("trl-internal-testing/tiny-Idefics2ForConditionalGeneration",),
diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py
index 8ab8819675a..07adcd0d651 100644
--- a/tests/test_orpo_trainer.py
+++ b/tests/test_orpo_trainer.py
@@ -25,7 +25,7 @@
from .testing_utils import TrlTestCase
-class ORPOTrainerTester(TrlTestCase):
+class TestORPOTrainer(TrlTestCase):
def setup_method(self):
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py
index b8bacb7cce2..9d946db923c 100644
--- a/tests/test_peft_models.py
+++ b/tests/test_peft_models.py
@@ -32,7 +32,7 @@
@require_peft
-class PeftModelTester(TrlTestCase):
+class TestPeftModel(TrlTestCase):
def setup_method(self):
self.causal_lm_model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.lora_config = LoraConfig(
diff --git a/tests/test_prm_trainer.py b/tests/test_prm_trainer.py
index f435c1295ab..52f4da21df2 100644
--- a/tests/test_prm_trainer.py
+++ b/tests/test_prm_trainer.py
@@ -204,7 +204,7 @@ def test_tokenize_row_multi_token_separator(self):
}
-class PRMTrainerTester(TrlTestCase):
+class TestPRMTrainer(TrlTestCase):
def setup_method(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForTokenClassification.from_pretrained(model_id)
diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py
index ba8300c78ff..9f86c8a9111 100644
--- a/tests/test_reward_trainer.py
+++ b/tests/test_reward_trainer.py
@@ -108,7 +108,7 @@ def test_collate_with_margin(self):
torch.testing.assert_close(result["margin"], torch.tensor([0.1, 0.2]))
-class RewardTrainerTester(TrlTestCase):
+class TestRewardTrainer(TrlTestCase):
@parameterized.expand(
[
("trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",),
diff --git a/tests/test_rewards.py b/tests/test_rewards.py
index 46fddb5b19d..0f584d1b58a 100644
--- a/tests/test_rewards.py
+++ b/tests/test_rewards.py
@@ -19,7 +19,7 @@
from .testing_utils import TrlTestCase
-class ThinkFormatRewardTester(TrlTestCase):
+class TestThinkFormatReward(TrlTestCase):
def test_valid_format(self):
completions = [
"This is my reasoning.This is my answer.", # Simple, one-line reasoning
diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py
index 23f709d466a..18b402bd940 100644
--- a/tests/test_rloo_trainer.py
+++ b/tests/test_rloo_trainer.py
@@ -37,7 +37,7 @@
from peft import LoraConfig, PeftModel
-class RLOOTrainerTester(TrlTestCase):
+class TestRLOOTrainer(TrlTestCase):
def test_init_minimal(self):
# Test that RLOOTrainer can be instantiated with only model, reward_model and train_dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index b60ea48bfaa..d5cbcc07b76 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -33,7 +33,7 @@
from peft import LoraConfig, PeftModel, PromptEncoderConfig, TaskType, get_peft_model
-class DFTLossTester(TrlTestCase):
+class TestDFTLoss(TrlTestCase):
def test_dft_loss(self):
batch_size = 2
seq_len = 3
@@ -239,7 +239,7 @@ def test_multiple_examples(self):
assert torch.equal(result[1], torch.arange(3))
-class SFTTrainerTester(TrlTestCase):
+class TestSFTTrainer(TrlTestCase):
@parameterized.expand(
[
("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",),
diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py
index b9a809ba071..b76110d5f17 100644
--- a/tests/test_trainers_args.py
+++ b/tests/test_trainers_args.py
@@ -44,7 +44,7 @@
from .testing_utils import TrlTestCase, require_sklearn
-class TrainerArgTester(TrlTestCase):
+class TestTrainerArg(TrlTestCase):
@require_sklearn
def test_bco(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
diff --git a/tests/test_utils.py b/tests/test_utils.py
index d663328e244..9bf457b590e 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -534,7 +534,7 @@ def test_no_tensors(self):
assert torch.equal(new_mask, expected_mask)
-class RepeatRandomSamplerTester(TrlTestCase):
+class TestRepeatRandomSampler(TrlTestCase):
def test_sampler(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]
sampler = RepeatSampler(dataset, mini_repeat_count=2)
@@ -841,7 +841,7 @@ def test_selective_log_softmax(self, dtype):
torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5)
-class ShuffleSequenceDictTester(TrlTestCase):
+class TestShuffleSequenceDict(TrlTestCase):
def test_shuffle_preserves_shape(self):
x = torch.arange(6).reshape(3, 2)
y = torch.arange(3).reshape(3, 1)
@@ -906,7 +906,7 @@ def test_shuffle_with_list(self):
pytest.fail("Unexpected x row in shuffled output.")
-class SplitTensorDictTester(TrlTestCase):
+class TestSplitTensorDict(TrlTestCase):
def test_split_equal_chunks(self):
x = torch.arange(12).reshape(6, 2)
y = torch.arange(6).reshape(6, 1)
@@ -946,7 +946,7 @@ def test_with_scalar(self):
assert torch.equal(result[i]["y"], torch.tensor(1))
-class SplitPixelValuesByGridTester(TrlTestCase):
+class TestSplitPixelValuesByGrid(TrlTestCase):
def test_split_correctly_0(self):
batch = {
"image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 2]]),
@@ -1010,7 +1010,7 @@ def test_multi_images(self):
assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]]))
-class TruncateWithProtectedTokensTester(TrlTestCase):
+class TestTruncateWithProtectedTokens(TrlTestCase):
def test_basic_example(self):
"""Test the basic example from the problem description."""
prompt_ids = [1, 2, 3, 4, 5]
@@ -1088,7 +1088,7 @@ def test_order_preservation(self):
assert new_ids == expected_ids
-class UnsplitPixelValuesByGridTester(TrlTestCase):
+class TestUnsplitPixelValuesByGrid(TrlTestCase):
def test_unsplit_correctly(self):
pixel_values = [torch.randn(4, 5), torch.randn(2, 5)]
pixel_values_merged = torch.cat(pixel_values, dim=0)
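The renames above follow pytest's default discovery rule (python_classes = "Test*"): a plain class is only collected if its name starts with Test, whereas the old *Tester names had presumably been picked up through the unittest.TestCase machinery. A hypothetical illustration of the difference, assuming default pytest settings:

class ClampTester:  # plain class whose name does not match "Test*": not collected
    def test_upper_bound(self):
        assert min(1.0, 1.7) == 1.0


class TestClamp:  # collected under the default naming rule
    def test_upper_bound(self):
        assert min(1.0, 1.7) == 1.0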
From 01d1262eec50d15525075df4c71125bebc2972e4 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral
<8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 1 Oct 2025 14:59:55 +0200
Subject: [PATCH 11/16] Fix naming of subclass
---
tests/test_modeling_value_head.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py
index 65716cde99f..95445b31421 100644
--- a/tests/test_modeling_value_head.py
+++ b/tests/test_modeling_value_head.py
@@ -50,7 +50,7 @@
class BaseTester:
- class TestVHeadModel(TrlTestCase):
+ class VHeadModelTester(TrlTestCase):
all_model_names = None
trl_model_class = None
transformers_model_class = None
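Keeping the nested base named VHeadModelTester means pytest does not collect it directly; only the concrete Test* subclasses run the shared test methods. A sketch of the pattern with hypothetical names:

class BaseTester:
    class VHeadModelTester:  # not collected: name does not start with "Test"
        trl_model_class = None

        def test_model_class_is_configured(self):
            assert self.trl_model_class is not None


class TestToyValueHeadModel(BaseTester.VHeadModelTester):
    trl_model_class = object  # placeholder standing in for a real TRL model class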
From 309e76d40f1f05f0138ab32b5a809be29979c6dc Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral
<8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 1 Oct 2025 15:12:03 +0200
Subject: [PATCH 12/16] Replace assertLogs
---
tests/test_grpo_trainer.py | 6 +++---
tests/test_rloo_trainer.py | 6 +++---
tests/test_utils.py | 6 +++---
3 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index 3b17b3554b7..80d5eb97ec3 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -1027,7 +1027,7 @@ def test_training_with_mask_truncated_completions_all_masked(self):
new_param = trainer.model.get_parameter(n)
assert torch.equal(param, new_param), f"Parameter {n} has changed."
- def test_warning_raised_all_rewards_none(self):
+ def test_warning_raised_all_rewards_none(self, caplog):
"""Test that a proper warning is raised when all rewards are None."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -1050,11 +1050,11 @@ def always_none_reward_func(completions, **kwargs):
train_dataset=dataset,
)
- with self.assertLogs("trl.trainer.grpo_trainer", level="WARNING") as cm:
+ with caplog.at_level("WARNING", logger="trl.trainer.grpo_trainer"):
trainer.train()
expected_warning = "All reward functions returned None for the following kwargs:"
- assert expected_warning in cm.output[0]
+ assert expected_warning in caplog.text
def test_training_num_generations_larger_than_batch_size(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py
index 18b402bd940..a8554414dff 100644
--- a/tests/test_rloo_trainer.py
+++ b/tests/test_rloo_trainer.py
@@ -882,7 +882,7 @@ def test_training_with_mask_truncated_completions_all_masked(self):
new_param = trainer.model.get_parameter(n)
assert torch.equal(param, new_param), f"Parameter {n} has changed."
- def test_warning_raised_all_rewards_none(self):
+ def test_warning_raised_all_rewards_none(self, caplog):
"""Test that a proper warning is raised when all rewards are None."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -905,11 +905,11 @@ def always_none_reward_func(completions, **kwargs):
train_dataset=dataset,
)
- with self.assertLogs("trl.trainer.rloo_trainer", level="WARNING") as cm:
+ with caplog.at_level("WARNING", logger="trl.trainer.rloo_trainer"):
trainer.train()
expected_warning = "All reward functions returned None for the following kwargs:"
- assert expected_warning in cm.output[0]
+ assert expected_warning in caplog.text
def test_training_num_generations_larger_than_batch_size(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 9bf457b590e..cdedcf95ad9 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -422,7 +422,7 @@ def test_token_classification_task_with_ignored_tokens_1(self):
result = compute_accuracy(eval_pred)
assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0
- def test_rewards_comparison_task(self):
+ def test_rewards_comparison_task(self, caplog):
eval_pred = (
np.array(
[
@@ -435,7 +435,7 @@ def test_rewards_comparison_task(self):
)
expected_accuracy = 0.5 # 1 match, 1 mismatch, 1 equal (ignored)
- with self.assertLogs("trl.trainer.utils", level="WARNING") as cm:
+ with caplog.at_level("WARNING", logger="trl.trainer.utils"):
result = compute_accuracy(eval_pred)
assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0
@@ -443,7 +443,7 @@ def test_rewards_comparison_task(self):
"There are 1 out of 3 instances where the predictions for both options are equal. "
"These instances are ignored in the accuracy computation."
)
- assert expected_warning in cm.output[0]
+ assert expected_warning in caplog.text
class TestFlushLeft(TrlTestCase):
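A minimal caplog sketch mirroring the pattern above (hypothetical logger name and message): the fixture captures records emitted inside the with block and exposes them through caplog.text, which replaces unittest's assertLogs and its cm.output list.

import logging


def emit_reward_warning():
    logging.getLogger("trl.trainer.grpo_trainer").warning("All reward functions returned None")


def test_warning_is_captured(caplog):
    with caplog.at_level("WARNING", logger="trl.trainer.grpo_trainer"):
        emit_reward_warning()
    assert "All reward functions returned None" in caplog.text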
From 0d7815cd60b63a12f0ea1c1eb7069f68a77b5640 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral
<8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 1 Oct 2025 15:50:32 +0200
Subject: [PATCH 13/16] Revert "Set CI for debugging"
This reverts commit e2d1e61316ced2066aeb8b95fd09809eba40f495.
---
.github/workflows/tests.yml | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 4c21a98f4f9..4231ef227ec 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -21,7 +21,7 @@ jobs:
check_code_quality:
name: Check code quality
runs-on: ubuntu-latest
- # if: github.event.pull_request.draft == false
+ if: github.event.pull_request.draft == false
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.12
@@ -36,7 +36,7 @@ jobs:
name: Tests
strategy:
matrix:
- python-version: ['3.10']
+ python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
fail-fast: false
runs-on:
group: aws-g4dn-2xlarge
@@ -46,7 +46,7 @@ jobs:
defaults:
run:
shell: bash
- # if: github.event.pull_request.draft == false
+ if: github.event.pull_request.draft == false
steps:
- name: Git checkout
uses: actions/checkout@v4
From 9428f4f564774de9ef03857ec6113e05fa1fa165 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral
<8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 1 Oct 2025 19:32:18 +0200
Subject: [PATCH 14/16] Replace require_peft
---
tests/slow/test_dpo_slow.py | 4 ++--
tests/slow/test_grpo_slow.py | 3 +--
tests/slow/test_sft_slow.py | 3 +--
tests/test_activation_offloading.py | 4 ++--
tests/test_bco_trainer.py | 3 +--
tests/test_callbacks.py | 5 ++---
tests/test_cpo_trainer.py | 3 +--
tests/test_dpo_trainer.py | 3 +--
tests/test_grpo_trainer.py | 4 ++--
tests/test_kto_trainer.py | 4 ++--
tests/test_nash_md_trainer.py | 3 +--
tests/test_online_dpo_trainer.py | 4 ++--
tests/test_orpo_trainer.py | 3 +--
tests/test_peft_models.py | 7 ++-----
tests/test_ppo_trainer.py | 3 +--
tests/test_prm_trainer.py | 3 +--
tests/test_reward_trainer.py | 3 +--
tests/test_rloo_trainer.py | 4 ++--
tests/test_sft_trainer.py | 4 ++--
tests/test_utils.py | 3 +--
tests/test_xpo_trainer.py | 3 +--
tests/testing_utils.py | 3 ++-
22 files changed, 32 insertions(+), 47 deletions(-)
diff --git a/tests/slow/test_dpo_slow.py b/tests/slow/test_dpo_slow.py
index 498c71179ac..26feb388c6b 100644
--- a/tests/slow/test_dpo_slow.py
+++ b/tests/slow/test_dpo_slow.py
@@ -21,12 +21,12 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
-from transformers.testing_utils import backend_empty_cache, require_peft, require_torch_accelerator, torch_device
+from transformers.testing_utils import backend_empty_cache, require_torch_accelerator, torch_device
from transformers.utils import is_peft_available
from trl import DPOConfig, DPOTrainer
-from ..testing_utils import TrlTestCase, require_bitsandbytes
+from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft
from .testing_constants import DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST
diff --git a/tests/slow/test_grpo_slow.py b/tests/slow/test_grpo_slow.py
index 63eecdf1398..17798d10b11 100644
--- a/tests/slow/test_grpo_slow.py
+++ b/tests/slow/test_grpo_slow.py
@@ -35,7 +35,6 @@
backend_empty_cache,
require_flash_attn,
require_liger_kernel,
- require_peft,
require_torch_accelerator,
torch_device,
)
@@ -44,7 +43,7 @@
from trl import GRPOConfig, GRPOTrainer
from trl.trainer.utils import get_kbit_device_map
-from ..testing_utils import TrlTestCase, require_bitsandbytes, require_vllm
+from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft, require_vllm
from .testing_constants import MODELS_TO_TEST
diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py
index 114a9069943..b6928b697c7 100755
--- a/tests/slow/test_sft_slow.py
+++ b/tests/slow/test_sft_slow.py
@@ -24,7 +24,6 @@
from transformers.testing_utils import (
backend_empty_cache,
require_liger_kernel,
- require_peft,
require_torch_accelerator,
require_torch_multi_accelerator,
torch_device,
@@ -33,7 +32,7 @@
from trl import SFTConfig, SFTTrainer
-from ..testing_utils import TrlTestCase, require_bitsandbytes
+from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft
from .testing_constants import DEVICE_MAP_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST, PACKING_OPTIONS
diff --git a/tests/test_activation_offloading.py b/tests/test_activation_offloading.py
index 6c9ae24b8ae..d1a9ea921f5 100644
--- a/tests/test_activation_offloading.py
+++ b/tests/test_activation_offloading.py
@@ -16,12 +16,12 @@
import torch
from torch import nn
from transformers import AutoModelForCausalLM
-from transformers.testing_utils import require_peft, require_torch_accelerator, torch_device
+from transformers.testing_utils import require_torch_accelerator, torch_device
from transformers.utils import is_peft_available
from trl.models.activation_offloading import NoOpManager, OffloadActivations
-from .testing_utils import TrlTestCase
+from .testing_utils import TrlTestCase, require_peft
if is_peft_available():
diff --git a/tests/test_bco_trainer.py b/tests/test_bco_trainer.py
index 30214a6f781..7b7f0414438 100644
--- a/tests/test_bco_trainer.py
+++ b/tests/test_bco_trainer.py
@@ -20,13 +20,12 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
-from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available
from trl import BCOConfig, BCOTrainer
from trl.trainer.bco_trainer import _process_tokens, _tokenize
-from .testing_utils import TrlTestCase, require_no_wandb, require_sklearn
+from .testing_utils import TrlTestCase, require_no_wandb, require_peft, require_sklearn
if is_peft_available():
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index c41d68b7efa..316c9b35ae1 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -18,11 +18,10 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments
-from transformers.testing_utils import require_peft, require_wandb
+from transformers.testing_utils import require_wandb
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_peft_available
-from tests.testing_utils import require_comet, require_mergekit
from trl import (
BasePairwiseJudge,
BEMACallback,
@@ -34,7 +33,7 @@
)
from trl.mergekit_utils import MergeConfig
-from .testing_utils import TrlTestCase
+from .testing_utils import TrlTestCase, require_comet, require_mergekit, require_peft
if is_peft_available():
diff --git a/tests/test_cpo_trainer.py b/tests/test_cpo_trainer.py
index d2c926a6735..56792f608dc 100644
--- a/tests/test_cpo_trainer.py
+++ b/tests/test_cpo_trainer.py
@@ -17,12 +17,11 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
-from transformers.testing_utils import require_peft
from trl import CPOConfig, CPOTrainer
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
-from .testing_utils import TrlTestCase
+from .testing_utils import TrlTestCase, require_peft
class TestCPOTrainer(TrlTestCase):
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index 8455909c242..1dcecdb79f4 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -34,14 +34,13 @@
from transformers.testing_utils import (
get_device_properties,
require_liger_kernel,
- require_peft,
require_torch_gpu_if_bnb_not_multi_backend_enabled,
require_vision,
)
from trl import DPOConfig, DPOTrainer, FDivergenceType
-from .testing_utils import TrlTestCase, require_bitsandbytes, require_no_wandb
+from .testing_utils import TrlTestCase, require_bitsandbytes, require_no_wandb, require_peft
if is_vision_available():
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index e3fa06eeb58..1e504ffe4a8 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -25,7 +25,7 @@
AutoModelForSequenceClassification,
AutoTokenizer,
)
-from transformers.testing_utils import require_liger_kernel, require_peft, require_vision
+from transformers.testing_utils import require_liger_kernel, require_vision
from transformers.utils import is_peft_available
from trl import GRPOConfig, GRPOTrainer
@@ -36,7 +36,7 @@
)
from trl.experimental.gspo_token import GRPOTrainer as GSPOTokenTrainer
-from .testing_utils import TrlTestCase, require_vllm
+from .testing_utils import TrlTestCase, require_peft, require_vllm
if is_peft_available():
diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py
index c51cf5a0342..e2c325149f2 100644
--- a/tests/test_kto_trainer.py
+++ b/tests/test_kto_trainer.py
@@ -18,12 +18,12 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
-from transformers.testing_utils import require_liger_kernel, require_peft
+from transformers.testing_utils import require_liger_kernel
from trl import KTOConfig, KTOTrainer
from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize
-from .testing_utils import TrlTestCase, require_no_wandb
+from .testing_utils import TrlTestCase, require_no_wandb, require_peft
class TestKTOTrainer(TrlTestCase):
diff --git a/tests/test_nash_md_trainer.py b/tests/test_nash_md_trainer.py
index 90df42b48a7..d6026e73443 100644
--- a/tests/test_nash_md_trainer.py
+++ b/tests/test_nash_md_trainer.py
@@ -16,12 +16,11 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
-from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available
from trl import NashMDConfig, NashMDTrainer
-from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender
+from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_peft
if is_peft_available():
diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py
index 33bc7bdd3b8..9d2e2ccf8da 100644
--- a/tests/test_online_dpo_trainer.py
+++ b/tests/test_online_dpo_trainer.py
@@ -18,12 +18,12 @@
from packaging.version import Version
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
-from transformers.testing_utils import require_peft, require_torch_accelerator, require_vision
+from transformers.testing_utils import require_torch_accelerator, require_vision
from transformers.utils import is_peft_available, is_vision_available
from trl import OnlineDPOConfig, OnlineDPOTrainer
-from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_vllm
+from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_peft, require_vllm
if is_peft_available():
diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py
index 07adcd0d651..dedfc4c36c9 100644
--- a/tests/test_orpo_trainer.py
+++ b/tests/test_orpo_trainer.py
@@ -17,12 +17,11 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
-from transformers.testing_utils import require_peft
from trl import ORPOConfig, ORPOTrainer
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
-from .testing_utils import TrlTestCase
+from .testing_utils import TrlTestCase, require_peft
class TestORPOTrainer(TrlTestCase):
diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py
index 9d946db923c..508ad175565 100644
--- a/tests/test_peft_models.py
+++ b/tests/test_peft_models.py
@@ -16,15 +16,12 @@
import torch
from transformers import AutoModelForCausalLM
-from transformers.testing_utils import (
- require_peft,
- require_torch_gpu_if_bnb_not_multi_backend_enabled,
-)
+from transformers.testing_utils import require_torch_gpu_if_bnb_not_multi_backend_enabled
from transformers.utils import is_peft_available
from trl import AutoModelForCausalLMWithValueHead
-from .testing_utils import TrlTestCase
+from .testing_utils import TrlTestCase, require_peft
if is_peft_available():
diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py
index c4c69d24e10..6e62e742115 100644
--- a/tests/test_ppo_trainer.py
+++ b/tests/test_ppo_trainer.py
@@ -16,13 +16,12 @@
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
-from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available
from trl import PPOConfig, PPOTrainer
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
-from .testing_utils import TrlTestCase
+from .testing_utils import TrlTestCase, require_peft
if is_peft_available():
diff --git a/tests/test_prm_trainer.py b/tests/test_prm_trainer.py
index 52f4da21df2..16876c6df62 100644
--- a/tests/test_prm_trainer.py
+++ b/tests/test_prm_trainer.py
@@ -18,12 +18,11 @@
from datasets import Dataset, load_dataset
from parameterized import parameterized
from transformers import AutoModelForTokenClassification, AutoTokenizer, PreTrainedTokenizerBase
-from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available
from trl import PRMConfig, PRMTrainer
-from .testing_utils import TrlTestCase
+from .testing_utils import TrlTestCase, require_peft
if is_peft_available():
diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py
index 9f86c8a9111..f5645ae445f 100644
--- a/tests/test_reward_trainer.py
+++ b/tests/test_reward_trainer.py
@@ -19,13 +19,12 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForSequenceClassification, AutoTokenizer
-from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available
from trl import RewardConfig, RewardTrainer
from trl.trainer.reward_trainer import DataCollatorForPreference
-from .testing_utils import TrlTestCase
+from .testing_utils import TrlTestCase, require_peft
if is_peft_available():
diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py
index 9c267674f47..de3c7c2a159 100644
--- a/tests/test_rloo_trainer.py
+++ b/tests/test_rloo_trainer.py
@@ -25,12 +25,12 @@
AutoModelForSequenceClassification,
AutoTokenizer,
)
-from transformers.testing_utils import require_peft, require_vision
+from transformers.testing_utils import require_vision
from transformers.utils import is_peft_available
from trl import RLOOConfig, RLOOTrainer
-from .testing_utils import TrlTestCase, require_vllm
+from .testing_utils import TrlTestCase, require_peft, require_vllm
if is_peft_available():
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index d5cbcc07b76..482c93df7e3 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -20,13 +20,13 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.testing_utils import require_flash_attn, require_liger_kernel, require_peft, require_vision
+from transformers.testing_utils import require_flash_attn, require_liger_kernel, require_vision
from transformers.utils import is_peft_available
from trl import SFTConfig, SFTTrainer
from trl.trainer.sft_trainer import DataCollatorForLanguageModeling, dft_loss
-from .testing_utils import TrlTestCase, ignore_warnings, require_bitsandbytes
+from .testing_utils import TrlTestCase, ignore_warnings, require_bitsandbytes, require_peft
if is_peft_available():
diff --git a/tests/test_utils.py b/tests/test_utils.py
index cdedcf95ad9..14b3fb570af 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -22,7 +22,6 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
-from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available
from trl import ModelConfig
@@ -47,7 +46,7 @@
unsplit_pixel_values_by_grid,
)
-from .testing_utils import TrlTestCase, require_rich
+from .testing_utils import TrlTestCase, require_peft, require_rich
if is_peft_available():
diff --git a/tests/test_xpo_trainer.py b/tests/test_xpo_trainer.py
index e4b60a810cf..4d41471187c 100644
--- a/tests/test_xpo_trainer.py
+++ b/tests/test_xpo_trainer.py
@@ -16,12 +16,11 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
-from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available
from trl import XPOConfig, XPOTrainer
-from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender
+from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_peft
if is_peft_available():
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
index fdc91ccab81..64904678446 100644
--- a/tests/testing_utils.py
+++ b/tests/testing_utils.py
@@ -22,7 +22,7 @@
import torch
from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available
from transformers.testing_utils import torch_device
-from transformers.utils import is_rich_available
+from transformers.utils import is_peft_available, is_rich_available
from trl import BaseBinaryJudge, BasePairwiseJudge
from trl.import_utils import is_joblib_available, is_llm_blender_available, is_mergekit_available, is_vllm_available
@@ -32,6 +32,7 @@
require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml")
require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender")
require_mergekit = pytest.mark.skipif(not is_mergekit_available(), reason="test requires mergekit")
+require_peft = pytest.mark.skipif(not is_peft_available(), reason="test requires peft")
require_rich = pytest.mark.skipif(not is_rich_available(), reason="test requires rich")
require_sklearn = pytest.mark.skipif(
not (is_sklearn_available() and is_joblib_available()), reason="test requires sklearn"
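Because the local require_peft is now a pytest marker like the others in testing_utils, it stacks with them on a test class and the skip applies to every method inside; a hypothetical sketch:

from .testing_utils import require_bitsandbytes, require_peft


@require_peft
@require_bitsandbytes
class TestQLoraPath:
    def test_placeholder(self):
        # runs only when both peft and bitsandbytes are installed
        assert True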
From bc60f7619e658b9d87007a5edac01e3b3b0fc7ee Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral
<8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 1 Oct 2025 19:38:59 +0200
Subject: [PATCH 15/16] Replace require_vision
---
tests/test_dpo_trainer.py | 3 +--
tests/test_grpo_trainer.py | 4 ++--
tests/test_online_dpo_trainer.py | 11 +++++++++--
tests/test_rloo_trainer.py | 3 +--
tests/test_sft_trainer.py | 4 ++--
tests/testing_utils.py | 3 ++-
6 files changed, 17 insertions(+), 11 deletions(-)
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index 1dcecdb79f4..aea42c8d8d0 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -35,12 +35,11 @@
get_device_properties,
require_liger_kernel,
require_torch_gpu_if_bnb_not_multi_backend_enabled,
- require_vision,
)
from trl import DPOConfig, DPOTrainer, FDivergenceType
-from .testing_utils import TrlTestCase, require_bitsandbytes, require_no_wandb, require_peft
+from .testing_utils import TrlTestCase, require_bitsandbytes, require_no_wandb, require_peft, require_vision
if is_vision_available():
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index 1e504ffe4a8..d63520bea43 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -25,7 +25,7 @@
AutoModelForSequenceClassification,
AutoTokenizer,
)
-from transformers.testing_utils import require_liger_kernel, require_vision
+from transformers.testing_utils import require_liger_kernel
from transformers.utils import is_peft_available
from trl import GRPOConfig, GRPOTrainer
@@ -36,7 +36,7 @@
)
from trl.experimental.gspo_token import GRPOTrainer as GSPOTokenTrainer
-from .testing_utils import TrlTestCase, require_peft, require_vllm
+from .testing_utils import TrlTestCase, require_peft, require_vision, require_vllm
if is_peft_available():
diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py
index 9d2e2ccf8da..f8706770371 100644
--- a/tests/test_online_dpo_trainer.py
+++ b/tests/test_online_dpo_trainer.py
@@ -18,12 +18,19 @@
from packaging.version import Version
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
-from transformers.testing_utils import require_torch_accelerator, require_vision
+from transformers.testing_utils import require_torch_accelerator
from transformers.utils import is_peft_available, is_vision_available
from trl import OnlineDPOConfig, OnlineDPOTrainer
-from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_peft, require_vllm
+from .testing_utils import (
+ RandomPairwiseJudge,
+ TrlTestCase,
+ require_llm_blender,
+ require_peft,
+ require_vision,
+ require_vllm,
+)
if is_peft_available():
diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py
index de3c7c2a159..005cd064a4d 100644
--- a/tests/test_rloo_trainer.py
+++ b/tests/test_rloo_trainer.py
@@ -25,12 +25,11 @@
AutoModelForSequenceClassification,
AutoTokenizer,
)
-from transformers.testing_utils import require_vision
from transformers.utils import is_peft_available
from trl import RLOOConfig, RLOOTrainer
-from .testing_utils import TrlTestCase, require_peft, require_vllm
+from .testing_utils import TrlTestCase, require_peft, require_vision, require_vllm
if is_peft_available():
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 482c93df7e3..5d1aacf2876 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -20,13 +20,13 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.testing_utils import require_flash_attn, require_liger_kernel, require_vision
+from transformers.testing_utils import require_flash_attn, require_liger_kernel
from transformers.utils import is_peft_available
from trl import SFTConfig, SFTTrainer
from trl.trainer.sft_trainer import DataCollatorForLanguageModeling, dft_loss
-from .testing_utils import TrlTestCase, ignore_warnings, require_bitsandbytes, require_peft
+from .testing_utils import TrlTestCase, ignore_warnings, require_bitsandbytes, require_peft, require_vision
if is_peft_available():
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
index 64904678446..cbe677255b5 100644
--- a/tests/testing_utils.py
+++ b/tests/testing_utils.py
@@ -22,7 +22,7 @@
import torch
from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available
from transformers.testing_utils import torch_device
-from transformers.utils import is_peft_available, is_rich_available
+from transformers.utils import is_peft_available, is_rich_available, is_vision_available
from trl import BaseBinaryJudge, BasePairwiseJudge
from trl.import_utils import is_joblib_available, is_llm_blender_available, is_mergekit_available, is_vllm_available
@@ -37,6 +37,7 @@
require_sklearn = pytest.mark.skipif(
not (is_sklearn_available() and is_joblib_available()), reason="test requires sklearn"
)
+require_vision = pytest.mark.skipif(not is_vision_available(), reason="test requires vision")
require_vllm = pytest.mark.skipif(not is_vllm_available(), reason="test requires vllm")
require_no_wandb = pytest.mark.skipif(is_wandb_available(), reason="test requires no wandb")
require_3_accelerators = pytest.mark.skipif(
From bf326893823ff186250f3256e5e9454f3dd422b7 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral
<8515462+albertvillanova@users.noreply.github.com>
Date: Thu, 2 Oct 2025 07:56:26 +0200
Subject: [PATCH 16/16] Replace unittest skip and remove main
---
tests/test_cli.py | 10 +++-------
tests/test_data_utils.py | 6 ------
tests/test_dpo_trainer.py | 11 +++--------
tests/test_grpo_trainer.py | 15 +++++----------
tests/test_judges.py | 5 +++--
tests/test_modeling_value_head.py | 6 +++---
tests/test_reward_trainer.py | 4 ++--
tests/test_rewards.py | 5 -----
tests/test_rloo_trainer.py | 13 ++++---------
9 files changed, 23 insertions(+), 52 deletions(-)
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 638e1d38493..48087f5054c 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -15,18 +15,18 @@
import os
import sys
-import unittest
from io import StringIO
from unittest.mock import patch
+import pytest
import yaml
from .testing_utils import TrlTestCase
-@unittest.skipIf(
+@pytest.mark.skipif(
sys.version_info < (3, 10),
- "Transformers' generation codebase uses a Python >3.10 syntax (`str | None`), which seems to cause the CLI tests "
+ reason="Transformers' generation codebase uses a Python >3.10 syntax (`str | None`), which seems to cause the CLI tests "
"to fail on Python <3.10.", # let's say it's a known issue, but not expected to be fixed, because too niche
)
class TestCLI(TrlTestCase):
@@ -113,7 +113,3 @@ def test_sft_config_file(self):
# Verify that output directory was created
assert os.path.exists(output_dir)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py
index 324680e6331..8fe8a24bd50 100644
--- a/tests/test_data_utils.py
+++ b/tests/test_data_utils.py
@@ -15,7 +15,6 @@
import copy
import itertools
import textwrap
-import unittest
from time import strftime
from datasets import Dataset, DatasetDict
@@ -977,8 +976,3 @@ def test_already_chatml(self):
]
}
assert maybe_convert_to_chatml(example) == example
-
-
-# Run the tests
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index aea42c8d8d0..1d1e94fcf99 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -14,7 +14,6 @@
import re
import sys
-import unittest
from unittest.mock import MagicMock
import numpy as np
@@ -714,9 +713,9 @@ def test_dpo_lora_bf16_autocast_llama(self):
)
@require_bitsandbytes
@require_peft
- @unittest.skipIf(
+ @pytest.mark.skipif(
get_device_properties()[0] == "cuda" and get_device_properties()[1] < 8,
- "Skipping because bf16 not supported on CUDA GPU with capability < 8.0",
+ reason="Skipping because bf16 not supported on CUDA GPU with capability < 8.0",
)
def test_dpo_lora_bf16_autocast(self, loss_type, pre_compute, gen_during_eval):
from peft import LoraConfig
@@ -1301,7 +1300,7 @@ def test_train_with_length_desensitization(self):
]
)
@require_liger_kernel
- @unittest.skipUnless(sys.version_info >= (3, 10), "Liger kernel is not supported on Python 3.9")
+ @pytest.mark.skipif(not (sys.version_info >= (3, 10)), reason="Liger kernel is not supported on Python 3.9")
def test_dpo_trainer_with_liger(self, beta, loss_type):
"""Test DPO trainer with Liger loss enabled across supported loss types.
@@ -1512,7 +1511,3 @@ def test_f_divergence_type(self, f_divergence_type, as_string: bool):
# Serialization: TrainingArguments.to_dict should yield the enum's string value
configparser_dict = training_args.to_dict()
assert configparser_dict["f_divergence_type"] == f_divergence_type.value
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index d63520bea43..fb7e76f1594 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import unittest
from unittest.mock import patch
import pytest
@@ -720,9 +719,9 @@ def test_training_with_entropy_filter(self):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
- @unittest.skip("We should add a mock for the vLLM server.")
@require_peft
@require_vllm
+ @pytest.mark.skip(reason="We should add a mock for the vLLM server.")
def test_training_vllm_and_peft(self):
"""Test that training works with vLLM for generation."""
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") # tiny model is too small for vLLM
@@ -767,7 +766,7 @@ def test_training_vllm_and_peft(self):
assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
@require_vllm
- @unittest.skip("We should add a mock for the vLLM server.")
+ @pytest.mark.skip(reason="We should add a mock for the vLLM server.")
def test_training_vllm_guided_decoding(self):
"""Test that training works with vLLM for generation with guided decoding."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -801,7 +800,7 @@ def test_training_vllm_guided_decoding(self):
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_vllm
- @unittest.skip("We should add a mock for the vLLM server.")
+ @pytest.mark.skip(reason="We should add a mock for the vLLM server.")
def test_training_vllm_importance_sampling_correction(self):
"""Test that training works with vLLM for generation with guided decoding."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -871,7 +870,7 @@ def test_training_with_additional_generation_kwargs(self):
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_vllm
- @unittest.skip("We should add a mock for the vLLM server.")
+ @pytest.mark.skip(reason="We should add a mock for the vLLM server.")
def test_training_vllm_with_additional_generation_kwargs(self):
"""Test that training works with vLLM and additional generation kwargs."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -1521,7 +1520,7 @@ def reward_func(completions, **kwargs):
)
@require_vision
@require_vllm
- @unittest.skip("We should add a mock for the vLLM server.")
+ @pytest.mark.skip(reason="We should add a mock for the vLLM server.")
def test_training_vlm_and_vllm(self, model_id) -> None:
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
@@ -2006,7 +2005,3 @@ def test_training(self):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_judges.py b/tests/test_judges.py
index 9238a7f5cb2..bba1ffffbbb 100644
--- a/tests/test_judges.py
+++ b/tests/test_judges.py
@@ -13,7 +13,8 @@
# limitations under the License.
import time
-import unittest
+
+import pytest
from trl import AllTrueJudge, HfPairwiseJudge, PairRMJudge
@@ -38,7 +39,7 @@ def test_all_true_judge(self):
assert len(judgements) == 2
assert all(judgement in {0, 1, -1} for judgement in judgements)
- @unittest.skip("This test needs to be run manually since it requires a valid Hugging Face API key.")
+ @pytest.mark.skip(reason="This test needs to be run manually since it requires a valid Hugging Face API key.")
def test_hugging_face_judge(self):
judge = HfPairwiseJudge()
prompts, completions = self._get_prompts_and_pairwise_completions()
diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py
index 95445b31421..a2fde5a12c4 100644
--- a/tests/test_modeling_value_head.py
+++ b/tests/test_modeling_value_head.py
@@ -13,8 +13,8 @@
# limitations under the License.
import gc
-import unittest
+import pytest
import torch
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, GenerationConfig
@@ -273,7 +273,7 @@ def test_transformers_bf16_kwargs(self):
# check dummy forward pass works in half precision
_ = trl_model(dummy_input)
- @unittest.skip("This test needs to be run manually due to HF token issue.")
+ @pytest.mark.skip(reason="This test needs to be run manually due to HF token issue.")
def test_push_to_hub(self):
for model_name in self.all_model_names:
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
@@ -364,7 +364,7 @@ def test_generate(self, model_name):
# Just check if the generation works
_ = model.generate(input_ids, decoder_input_ids=decoder_input_ids, generation_config=generation_config)
- @unittest.skip("This test needs to be run manually due to HF token issue.")
+ @pytest.mark.skip(reason="This test needs to be run manually due to HF token issue.")
def test_push_to_hub(self):
for model_name in self.all_model_names:
model = self.trl_model_class.from_pretrained(model_name)
diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py
index f5645ae445f..ab6d6656e99 100644
--- a/tests/test_reward_trainer.py
+++ b/tests/test_reward_trainer.py
@@ -13,8 +13,8 @@
# limitations under the License.
import pathlib
-import unittest
+import pytest
import torch
from datasets import load_dataset
from parameterized import parameterized
@@ -653,7 +653,7 @@ def test_train_with_set_chat_template_from_path(self):
original_template_content = f.read()
assert template_content == original_template_content, "Chat template content does not match the original"
- @unittest.skip("Skipping until we have a dataset with tool calls")
+ @pytest.mark.skip(reason="Skipping until we have a dataset with tool calls")
def test_train_toolcall_data(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/toolcall", split="train")
diff --git a/tests/test_rewards.py b/tests/test_rewards.py
index 0f584d1b58a..0764ce5d9ea 100644
--- a/tests/test_rewards.py
+++ b/tests/test_rewards.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import unittest
from trl.rewards import get_soft_overlong_punishment, think_format_reward
@@ -86,7 +85,3 @@ def test_soft_overlong_punishment_intermediate_completion(self):
completion_ids = [[1] * 90] # 90 is between 80 and 100
rewards = reward_fn(completion_ids)
assert round(abs(rewards[0] - -0.5), 4) == 0
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py
index 005cd064a4d..1de4eca479e 100644
--- a/tests/test_rloo_trainer.py
+++ b/tests/test_rloo_trainer.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import unittest
from unittest.mock import patch
import pytest
@@ -580,9 +579,9 @@ def test_training_beta_zero(self):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
- @unittest.skip("We should add a mock for the vLLM server.")
@require_peft
@require_vllm
+ @pytest.mark.skip(reason="We should add a mock for the vLLM server.")
def test_training_vllm_and_peft(self):
"""Test that training works with vLLM for generation."""
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") # tiny model is too small for vLLM
@@ -627,7 +626,7 @@ def test_training_vllm_and_peft(self):
assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
@require_vllm
- @unittest.skip("We should add a mock for the vLLM server.")
+ @pytest.mark.skip(reason="We should add a mock for the vLLM server.")
def test_training_vllm_guided_decoding(self):
"""Test that training works with vLLM for generation with guided decoding."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -696,7 +695,7 @@ def test_training_with_additional_generation_kwargs(self):
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
@require_vllm
- @unittest.skip("We should add a mock for the vLLM server.")
+ @pytest.mark.skip(reason="We should add a mock for the vLLM server.")
def test_training_vllm_with_additional_generation_kwargs(self):
"""Test that training works with vLLM and additional generation kwargs."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -1262,7 +1261,7 @@ def reward_func(completions, **kwargs):
)
@require_vision
@require_vllm
- @unittest.skip("We should add a mock for the vLLM server.")
+ @pytest.mark.skip(reason="We should add a mock for the vLLM server.")
def test_training_vlm_and_vllm(self, model_id) -> None:
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
@@ -1416,7 +1415,3 @@ def test_single_reward_model_with_single_processing_class(self):
assert len(trainer.reward_processing_classes) == 1
assert trainer.reward_processing_classes[0] == single_processing_class
-
-
-if __name__ == "__main__":
- unittest.main()