20 changes: 9 additions & 11 deletions tests/slow/test_dpo_slow.py
@@ -21,12 +21,12 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
-from transformers.testing_utils import backend_empty_cache, require_peft, require_torch_accelerator, torch_device
+from transformers.testing_utils import backend_empty_cache, require_torch_accelerator, torch_device
from transformers.utils import is_peft_available

from trl import DPOConfig, DPOTrainer

-from ..testing_utils import TrlTestCase, require_bitsandbytes
+from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft
from .testing_constants import DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST


@@ -37,9 +37,8 @@
@pytest.mark.slow
@require_torch_accelerator
@require_peft
-class DPOTrainerSlowTester(TrlTestCase):
-    def setUp(self):
-        super().setUp()
+class TestDPOTrainerSlow(TrlTestCase):
+    def setup_method(self):
self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
self.peft_config = LoraConfig(
lora_alpha=16,
@@ -50,11 +49,10 @@ def setUp(self):
)
self.max_length = 128

-    def tearDown(self):
+    def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
gc.collect()
-        super().tearDown()

@parameterized.expand(list(itertools.product(MODELS_TO_TEST, DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS)))
def test_dpo_bare_model(self, model_id, loss_type, pre_compute_logits):
@@ -151,8 +149,8 @@ def test_dpo_peft_model(self, model_id, loss_type, pre_compute_logits, gradient_
peft_config=self.peft_config,
)

-        self.assertIsInstance(trainer.model, PeftModel)
-        self.assertIsNone(trainer.ref_model)
+        assert isinstance(trainer.model, PeftModel)
+        assert trainer.ref_model is None

# train the model
trainer.train()
@@ -215,8 +213,8 @@ def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gra
peft_config=self.peft_config,
)

-        self.assertIsInstance(trainer.model, PeftModel)
-        self.assertIsNone(trainer.ref_model)
+        assert isinstance(trainer.model, PeftModel)
+        assert trainer.ref_model is None

# train the model
trainer.train()
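Note: the same migration pattern repeats across all four files in this PR. Below is a minimal, self-contained sketch of the before/after; the class and attribute names are illustrative, not copied from the test suite.

```python
# unittest style (before): camelCase lifecycle hooks chained through super(),
# and self.assert* methods:
#
#     class DPOTrainerSlowTester(TrlTestCase):
#         def setUp(self):
#             super().setUp()
#             self.max_length = 128
#
# pytest style (after): pytest collects Test*-named classes and calls
# setup_method/teardown_method around each test, so no super() chaining is
# needed, and bare asserts replace self.assert* thanks to pytest's
# assertion rewriting.
class TrlTestCase:  # stand-in for the real base class in tests/testing_utils.py
    pass


class TestDPOTrainerSlow(TrlTestCase):
    def setup_method(self):
        # runs before every test method
        self.max_length = 128

    def teardown_method(self):
        # runs after every test method; the real tests also release
        # accelerator memory here (gc.collect + backend_empty_cache)
        pass

    def test_max_length(self):
        assert self.max_length == 128  # replaces self.assertEqual(...)
```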
53 changes: 23 additions & 30 deletions tests/slow/test_grpo_slow.py
@@ -35,7 +35,6 @@
backend_empty_cache,
require_flash_attn,
require_liger_kernel,
-    require_peft,
require_torch_accelerator,
torch_device,
)
@@ -44,7 +43,7 @@
from trl import GRPOConfig, GRPOTrainer
from trl.trainer.utils import get_kbit_device_map

-from ..testing_utils import TrlTestCase, require_bitsandbytes, require_vllm
+from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft, require_vllm
from .testing_constants import MODELS_TO_TEST


@@ -54,18 +53,16 @@

@pytest.mark.slow
@require_torch_accelerator
-class GRPOTrainerSlowTester(TrlTestCase):
-    def setUp(self):
-        super().setUp()
+class TestGRPOTrainerSlow(TrlTestCase):
+    def setup_method(self):
self.train_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
self.eval_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="test")
self.max_length = 128

-    def tearDown(self):
+    def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
gc.collect()
-        super().tearDown()

@parameterized.expand(MODELS_TO_TEST)
@require_liger_kernel
@@ -103,7 +100,7 @@ def test_training_with_liger_grpo_loss(self, model_name):

for n, param in previous_trainable_params.items():
new_param = model.get_parameter(n)
-            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

release_memory(model, trainer)

@@ -153,20 +150,20 @@ def test_training_with_liger_grpo_loss_and_peft(self, model_name):
# Verify PEFT adapter is properly initialized
from peft import PeftModel

-        self.assertTrue(isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT")
+        assert isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT"

# Store adapter weights before training
previous_trainable_params = {
n: param.clone() for n, param in trainer.model.named_parameters() if param.requires_grad
}
-        self.assertTrue(len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model")
+        assert len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model"

trainer.train()

# Verify adapter weights have changed after training
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
-            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

release_memory(model, trainer)

@@ -199,12 +196,12 @@ def test_training_with_transformers_paged(self, model_name):

trainer.train()

-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        assert trainer.state.log_history[-1]["train_loss"] is not None

# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = model.get_parameter(n)
-            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

release_memory(model, trainer)

@@ -310,13 +307,13 @@ def reward_func(prompts, completions, **kwargs):
peft_config=lora_config,
)

-            self.assertIsInstance(trainer.model, PeftModel)
+            assert isinstance(trainer.model, PeftModel)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

-            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+            assert trainer.state.log_history[-1]["train_loss"] is not None

# Check that LoRA parameters have changed
# For VLM models, we're more permissive about which parameters can change
@@ -328,7 +325,7 @@ def reward_func(prompts, completions, **kwargs):
lora_params_changed = True

# At least some LoRA parameters should have changed during training
-            self.assertTrue(lora_params_changed, "No LoRA parameters were updated during training.")
+            assert lora_params_changed, "No LoRA parameters were updated during training."

except torch.OutOfMemoryError as e:
self.skipTest(f"Skipping VLM training test due to insufficient GPU memory: {e}")
@@ -378,8 +375,8 @@ def test_vlm_processor_vllm_colocate_mode(self):
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct", use_fast=True, padding_side="left")

# Verify processor has both required attributes for VLM detection
-        self.assertTrue(hasattr(processor, "tokenizer"))
-        self.assertTrue(hasattr(processor, "image_processor"))
+        assert hasattr(processor, "tokenizer")
+        assert hasattr(processor, "image_processor")

def dummy_reward_func(completions, **kwargs):
return [1.0] * len(completions)
@@ -438,16 +435,14 @@ def dummy_reward_func(completions, **kwargs):
)

# Should detect VLM processor correctly and allow vLLM
-        self.assertTrue(trainer.use_vllm, "vLLM should be enabled for VLM processors in colocate mode")
-        self.assertEqual(trainer.vllm_mode, "colocate", "Should use colocate mode")
+        assert trainer.use_vllm, "vLLM should be enabled for VLM processors in colocate mode"
+        assert trainer.vllm_mode == "colocate", "Should use colocate mode"

# Check if signature columns were set properly
if trainer._signature_columns is not None:
# Should include 'image' in signature columns for VLM processors
-            self.assertIn(
-                "image",
-                trainer._signature_columns,
-                "Should include 'image' in signature columns for VLM",
+            assert "image" in trainer._signature_columns, (
+                "Should include 'image' in signature columns for VLM"
)

# Should not emit any warnings about VLM incompatibility
@@ -457,10 +452,8 @@
if "does not support VLMs" in str(w_item.message)
or "not compatible" in str(w_item.message).lower()
]
-        self.assertEqual(
-            len(incompatibility_warnings),
-            0,
-            f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}",
+        assert len(incompatibility_warnings) == 0, (
+            f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}"
)

# Test passes if we get this far without exceptions
@@ -525,12 +518,12 @@ def test_training_vllm(self):

trainer.train()

-            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+            assert trainer.state.log_history[-1]["train_loss"] is not None

# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
-                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+                assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

except Exception as e:
# If vLLM fails to initialize due to hardware constraints or other issues, that's expected
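Note: where the old `self.assert*` calls carried failure messages, the pytest form keeps them as the second operand of `assert`, wrapping multi-line messages in parentheses as in the hunks above. A small sketch of both forms (the tensor values and parameter name here are illustrative):

```python
# Sketch of the assertion style adopted above; pytest's assertion rewriting
# prints the operand values on failure, so messages can stay short.
import torch

param = torch.zeros(2)
new_param = torch.ones(2)
n = "model.layer.weight"

# one-line message
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

# multi-line message, parenthesized so the statement can wrap
assert not torch.equal(param, new_param), (
    f"Parameter {n} has not changed."
)
```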
23 changes: 10 additions & 13 deletions tests/slow/test_sft_slow.py
@@ -24,7 +24,6 @@
from transformers.testing_utils import (
backend_empty_cache,
require_liger_kernel,
-    require_peft,
require_torch_accelerator,
require_torch_multi_accelerator,
torch_device,
@@ -33,7 +32,7 @@

from trl import SFTConfig, SFTTrainer

-from ..testing_utils import TrlTestCase, require_bitsandbytes
+from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft
from .testing_constants import DEVICE_MAP_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST, PACKING_OPTIONS


@@ -44,9 +43,8 @@
@pytest.mark.slow
@require_torch_accelerator
@require_peft
-class SFTTrainerSlowTester(TrlTestCase):
-    def setUp(self):
-        super().setUp()
+class TestSFTTrainerSlow(TrlTestCase):
+    def setup_method(self):
self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]")
self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]")
self.max_length = 128
@@ -58,11 +56,10 @@ def setUp(self):
task_type="CAUSAL_LM",
)

-    def tearDown(self):
+    def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
gc.collect()
-        super().tearDown()

@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
def test_sft_trainer_str(self, model_name, packing):
@@ -148,7 +145,7 @@ def test_sft_trainer_peft(self, model_name, packing):
peft_config=self.peft_config,
)

-        self.assertIsInstance(trainer.model, PeftModel)
+        assert isinstance(trainer.model, PeftModel)

trainer.train()

@@ -252,7 +249,7 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient
peft_config=self.peft_config,
)

-        self.assertIsInstance(trainer.model, PeftModel)
+        assert isinstance(trainer.model, PeftModel)

trainer.train()

@@ -332,7 +329,7 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr
peft_config=self.peft_config,
)

-        self.assertIsInstance(trainer.model, PeftModel)
+        assert isinstance(trainer.model, PeftModel)

trainer.train()

@@ -372,7 +369,7 @@ def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
peft_config=self.peft_config,
)

-        self.assertIsInstance(trainer.model, PeftModel)
+        assert isinstance(trainer.model, PeftModel)

trainer.train()

@@ -447,11 +444,11 @@ def test_train_offloading(self, model_name, packing):
trainer.train()

# Check that the training loss is not None
-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        assert trainer.state.log_history[-1]["train_loss"] is not None

# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
-            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"

release_memory(trainer.model, trainer)
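Note: every file in this PR swaps `require_peft` from `transformers.testing_utils` to the local `..testing_utils`. The PR does not show that module, but a pytest-native marker of this kind is typically a one-liner; here is a hypothetical sketch (the real helper in tests/testing_utils.py may differ):

```python
# Hypothetical sketch of a local require_peft marker; NOT the actual TRL
# implementation, which this diff does not show.
import pytest
from transformers.utils import is_peft_available

require_peft = pytest.mark.skipif(not is_peft_available(), reason="test requires peft")
```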
15 changes: 7 additions & 8 deletions tests/test_activation_offloading.py
@@ -16,12 +16,12 @@
import torch
from torch import nn
from transformers import AutoModelForCausalLM
-from transformers.testing_utils import require_peft, require_torch_accelerator, torch_device
+from transformers.testing_utils import require_torch_accelerator, torch_device
from transformers.utils import is_peft_available

from trl.models.activation_offloading import NoOpManager, OffloadActivations

-from .testing_utils import TrlTestCase
+from .testing_utils import TrlTestCase, require_peft


if is_peft_available():
@@ -72,9 +72,8 @@ def test_offloading_with_peft_models(self) -> None:
for name_orig, grad_orig in grads_original:
for name_param, param in model.named_parameters():
if name_param == name_orig and param.requires_grad and param.grad is not None:
-                    self.assertTrue(
-                        torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5),
-                        f"Gradient mismatch for {name_orig}",
+                    assert torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5), (
+                        f"Gradient mismatch for {name_orig}"
)

@require_torch_accelerator
@@ -105,7 +104,7 @@ def test_noop_manager_with_offloading(self):

# Gradients should match as NoOpManager should have prevented offloading
for g1, g2 in zip(grads1, grads2):
-            self.assertTrue(torch.allclose(g1, g2, rtol=1e-4, atol=1e-5))
+            assert torch.allclose(g1, g2, rtol=1e-4, atol=1e-5)

@require_torch_accelerator
def test_min_offload_size(self):
@@ -152,6 +151,6 @@ def test_real_hf_model(self):
grads2 = [p.grad.clone() for p in model.parameters()]

# Check outputs and gradients match
-        self.assertTrue(torch.allclose(out1, out2, rtol=1e-5))
+        assert torch.allclose(out1, out2, rtol=1e-5)
for g1, g2 in zip(grads1, grads2):
-            self.assertTrue(torch.allclose(g1, g2, rtol=1e-5))
+            assert torch.allclose(g1, g2, rtol=1e-5)
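Note: the asserts above compare outputs and gradients computed with and without activation offloading. A hedged usage sketch of the two context managers under test follows; constructor arguments and context-manager usage are assumptions inferred from the imports and test names above, so see trl/models/activation_offloading.py for the real signatures, and expect offloading to want an accelerator-enabled setup:

```python
# Assumed usage of the classes imported at the top of this file; both are
# treated as context managers, which is an inference, not confirmed by the diff.
import torch
from torch import nn
from trl.models.activation_offloading import NoOpManager, OffloadActivations

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
x = torch.randn(8, 64)

with OffloadActivations():  # stash intermediate activations on CPU
    y = model(x)
    with NoOpManager():  # locally disable offloading, as in test_noop_manager_with_offloading
        y = y + model(x)

y.sum().backward()  # offloaded activations are fetched back during backward
```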