Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make cuda-only tests device-agnostic #2044

Merged
merged 6 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions tests/slow/test_dpo_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@

from trl import DPOConfig, DPOTrainer, is_peft_available

from ..testing_utils import require_bitsandbytes, require_peft, require_torch_gpu
from ..testing_utils import require_bitsandbytes, require_non_cpu, require_peft, torch_device
from .testing_constants import DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST


if is_peft_available():
from peft import LoraConfig, PeftModel


@require_torch_gpu
@require_non_cpu
class DPOTrainerSlowTester(unittest.TestCase):
def setUp(self):
self.dataset = load_dataset("trl-internal-testing/mlabonne-chatml-dpo-pairs-copy", split="train[:10%]")
Expand All @@ -47,7 +47,10 @@ def setUp(self):

def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
if torch_device == "cpu":
torch.cuda.empty_cache()
elif torch_device == "xpu":
torch.xpu.empty_cache()
gc.collect()

@parameterized.expand(list(itertools.product(MODELS_TO_TEST, DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS)))
Expand Down
8 changes: 4 additions & 4 deletions tests/slow/test_sft_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
from ..testing_utils import (
require_bitsandbytes,
require_liger_kernel,
require_multi_accelerator,
require_non_cpu,
require_peft,
require_torch_gpu,
require_torch_multi_gpu,
)
from .testing_constants import DEVICE_MAP_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST, PACKING_OPTIONS

Expand All @@ -39,7 +39,7 @@
from peft import LoraConfig, PeftModel


@require_torch_gpu
@require_non_cpu
class SFTTrainerSlowTester(unittest.TestCase):
def setUp(self):
self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]")
Expand Down Expand Up @@ -270,7 +270,7 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient
@parameterized.expand(
list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, DEVICE_MAP_OPTIONS))
)
@require_torch_multi_gpu
@require_multi_accelerator
def test_sft_trainer_transformers_mp_gc_device_map(
self, model_name, packing, gradient_checkpointing_kwargs, device_map
):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from trl.core import respond_to_batch

from .testing_constants import CI_HUB_ENDPOINT, CI_HUB_USER
from .testing_utils import require_peft, require_torch_multi_gpu
from .testing_utils import require_multi_accelerator, require_peft


EXPECTED_STATS = [
Expand Down Expand Up @@ -1038,7 +1038,7 @@ def test_push_to_hub(self):
)

@require_peft
@require_torch_multi_gpu
@require_multi_accelerator
def test_peft_model_ppo_trainer_multi_gpu(self):
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM
Expand Down
22 changes: 22 additions & 0 deletions tests/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import unittest

import torch
from accelerate.test_utils.testing import get_backend

from trl import (
is_bitsandbytes_available,
Expand All @@ -26,6 +27,9 @@
)


torch_device, device_count, memory_allocated_func = get_backend()


def require_peft(test_case):
"""
Decorator marking a test that requires peft. Skips the test if peft is not available.
Expand Down Expand Up @@ -115,3 +119,21 @@ def require_liger_kernel(test_case):
if not (torch.cuda.is_available() and is_liger_available()):
test_case = unittest.skip("test requires GPU and liger-kernel")(test_case)
return test_case


def require_non_cpu(test_case):
"""
Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when there are no
hardware accelerator available.
"""
return unittest.skipUnless(torch_device != "cpu", "test requires a hardware accelerator")(test_case)


def require_multi_accelerator(test_case):
"""
Decorator marking a test that requires multiple hardware accelerators. These tests are skipped on a machine without
multiple accelerators.
"""
return unittest.skipUnless(
torch_device != "cpu" and device_count > 1, "test requires multiple hardware accelerators"
)(test_case)
Loading