diff --git a/tests/slow/test_dpo_slow.py b/tests/slow/test_dpo_slow.py
index 26feb388c6b..3172d4137b8 100644
--- a/tests/slow/test_dpo_slow.py
+++ b/tests/slow/test_dpo_slow.py
@@ -21,12 +21,12 @@
 from datasets import load_dataset
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
-from transformers.testing_utils import backend_empty_cache, require_torch_accelerator, torch_device
+from transformers.testing_utils import backend_empty_cache, torch_device
 from transformers.utils import is_peft_available

 from trl import DPOConfig, DPOTrainer

-from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft
+from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft, require_torch_accelerator
 from .testing_constants import DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST
diff --git a/tests/slow/test_grpo_slow.py b/tests/slow/test_grpo_slow.py
index 7ba974423ca..17e7cc33118 100644
--- a/tests/slow/test_grpo_slow.py
+++ b/tests/slow/test_grpo_slow.py
@@ -31,19 +31,21 @@
     AutoTokenizer,
     BitsAndBytesConfig,
 )
-from transformers.testing_utils import (
-    backend_empty_cache,
-    require_flash_attn,
-    require_liger_kernel,
-    require_torch_accelerator,
-    torch_device,
-)
+from transformers.testing_utils import backend_empty_cache, torch_device
 from transformers.utils import is_peft_available

 from trl import GRPOConfig, GRPOTrainer
 from trl.trainer.utils import get_kbit_device_map

-from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft, require_vllm
+from ..testing_utils import (
+    TrlTestCase,
+    require_bitsandbytes,
+    require_flash_attn,
+    require_liger_kernel,
+    require_peft,
+    require_torch_accelerator,
+    require_vllm,
+)
 from .testing_constants import MODELS_TO_TEST
diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py
index 13d9c7ce635..49f04991bb8 100755
--- a/tests/slow/test_sft_slow.py
+++ b/tests/slow/test_sft_slow.py
@@ -21,18 +21,19 @@
 from datasets import load_dataset
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
-from transformers.testing_utils import (
-    backend_empty_cache,
-    require_liger_kernel,
-    require_torch_accelerator,
-    require_torch_multi_accelerator,
-    torch_device,
-)
+from transformers.testing_utils import backend_empty_cache, torch_device
 from transformers.utils import is_peft_available

 from trl import SFTConfig, SFTTrainer

-from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft
+from ..testing_utils import (
+    TrlTestCase,
+    require_bitsandbytes,
+    require_liger_kernel,
+    require_peft,
+    require_torch_accelerator,
+    require_torch_multi_accelerator,
+)
 from .testing_constants import DEVICE_MAP_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST, PACKING_OPTIONS
diff --git a/tests/test_activation_offloading.py b/tests/test_activation_offloading.py
index d1a9ea921f5..a116774e677 100644
--- a/tests/test_activation_offloading.py
+++ b/tests/test_activation_offloading.py
@@ -16,12 +16,12 @@
 import torch
 from torch import nn
 from transformers import AutoModelForCausalLM
-from transformers.testing_utils import require_torch_accelerator, torch_device
+from transformers.testing_utils import torch_device
 from transformers.utils import is_peft_available

 from trl.models.activation_offloading import NoOpManager, OffloadActivations

-from .testing_utils import TrlTestCase, require_peft
+from .testing_utils import TrlTestCase, require_peft, require_torch_accelerator

 if is_peft_available():
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index 316c9b35ae1..811bcf79f37 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -18,7 +18,6 @@
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments
-from transformers.testing_utils import require_wandb
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import is_peft_available
@@ -33,7 +32,7 @@
 )
 from trl.mergekit_utils import MergeConfig

-from .testing_utils import TrlTestCase, require_comet, require_mergekit, require_peft
+from .testing_utils import TrlTestCase, require_comet, require_mergekit, require_peft, require_wandb

 if is_peft_available():
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index e33eba91c3d..8286e238c65 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -30,16 +30,14 @@
     PreTrainedTokenizerBase,
     is_vision_available,
 )
-from transformers.testing_utils import (
-    get_device_properties,
-    require_liger_kernel,
-)
+from transformers.testing_utils import get_device_properties

 from trl import DPOConfig, DPOTrainer, FDivergenceType

 from .testing_utils import (
     TrlTestCase,
     require_bitsandbytes,
+    require_liger_kernel,
     require_no_wandb,
     require_peft,
     require_torch_gpu_if_bnb_not_multi_backend_enabled,
diff --git a/tests/test_gkd_trainer.py b/tests/test_gkd_trainer.py
index 27d3c8ff3ba..c880486d5e3 100644
--- a/tests/test_gkd_trainer.py
+++ b/tests/test_gkd_trainer.py
@@ -19,12 +19,11 @@
 import torch.nn.functional as F
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
-from transformers.testing_utils import require_liger_kernel

 from trl import GKDConfig, GKDTrainer
 from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

-from .testing_utils import TrlTestCase
+from .testing_utils import TrlTestCase, require_liger_kernel

 class TestGKDTrainerGenerateOnPolicy(TrlTestCase):
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index a89c5f9bac0..5f777cce72d 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -24,7 +24,6 @@
     AutoModelForSequenceClassification,
     AutoTokenizer,
 )
-from transformers.testing_utils import require_liger_kernel
 from transformers.utils import is_peft_available

 from trl import GRPOConfig, GRPOTrainer
@@ -35,7 +34,7 @@
 )
 from trl.experimental.gspo_token import GRPOTrainer as GSPOTokenTrainer

-from .testing_utils import TrlTestCase, require_peft, require_vision, require_vllm
+from .testing_utils import TrlTestCase, require_liger_kernel, require_peft, require_vision, require_vllm

 if is_peft_available():
diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py
index e2c325149f2..43e535e4150 100644
--- a/tests/test_kto_trainer.py
+++ b/tests/test_kto_trainer.py
@@ -18,12 +18,11 @@
 from datasets import load_dataset
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
-from transformers.testing_utils import require_liger_kernel

 from trl import KTOConfig, KTOTrainer
 from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize

-from .testing_utils import TrlTestCase, require_no_wandb, require_peft
+from .testing_utils import TrlTestCase, require_liger_kernel, require_no_wandb, require_peft

 class TestKTOTrainer(TrlTestCase):
diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py
index f8706770371..84c90b844a2 100644
--- a/tests/test_online_dpo_trainer.py
+++ b/tests/test_online_dpo_trainer.py
@@ -18,7 +18,6 @@
 from packaging.version import Version
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
-from transformers.testing_utils import require_torch_accelerator
 from transformers.utils import is_peft_available, is_vision_available

 from trl import OnlineDPOConfig, OnlineDPOTrainer
@@ -28,6 +27,7 @@
     TrlTestCase,
     require_llm_blender,
     require_peft,
+    require_torch_accelerator,
     require_vision,
     require_vllm,
 )
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 87e47b911d2..fa0c10d1df4 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -22,13 +22,20 @@
 from packaging.version import parse as parse_version
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.testing_utils import require_flash_attn, require_liger_kernel
 from transformers.utils import is_peft_available

 from trl import SFTConfig, SFTTrainer
 from trl.trainer.sft_trainer import DataCollatorForLanguageModeling, dft_loss

-from .testing_utils import TrlTestCase, ignore_warnings, require_bitsandbytes, require_peft, require_vision
+from .testing_utils import (
+    TrlTestCase,
+    ignore_warnings,
+    require_bitsandbytes,
+    require_flash_attn,
+    require_liger_kernel,
+    require_peft,
+    require_vision,
+)

 if is_peft_available():
diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py
index e36072f50c7..53b6614995e 100644
--- a/tests/test_vllm_client_server.py
+++ b/tests/test_vllm_client_server.py
@@ -17,12 +17,18 @@
 import pytest
 from transformers import AutoModelForCausalLM
-from transformers.testing_utils import require_torch_multi_accelerator, torch_device
+from transformers.testing_utils import torch_device

 from trl.extras.vllm_client import VLLMClient
 from trl.scripts.vllm_serve import chunk_list

-from .testing_utils import TrlTestCase, kill_process, require_3_accelerators, require_vllm
+from .testing_utils import (
+    TrlTestCase,
+    kill_process,
+    require_3_accelerators,
+    require_torch_multi_accelerator,
+    require_vllm,
+)

 class TestChunkList(TrlTestCase):
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
index 1d4992f3d7d..99a6e661f5c 100644
--- a/tests/testing_utils.py
+++ b/tests/testing_utils.py
@@ -22,12 +22,20 @@
 import pytest
 import torch
 from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available
-from transformers.testing_utils import torch_device
-from transformers.utils import is_peft_available, is_rich_available, is_vision_available
+from transformers.testing_utils import backend_device_count, torch_device
+from transformers.utils import (
+    is_flash_attn_2_available,
+    is_kernels_available,
+    is_peft_available,
+    is_rich_available,
+    is_torch_available,
+    is_vision_available,
+)

 from trl import BaseBinaryJudge, BasePairwiseJudge
 from trl.import_utils import (
     is_joblib_available,
+    is_liger_kernel_available,
     is_llm_blender_available,
     is_math_verify_available,
     is_mergekit_available,
@@ -37,6 +45,7 @@

 require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")
 require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml")
+require_liger_kernel = pytest.mark.skipif(not is_liger_kernel_available(), reason="test requires liger-kernel")
reason="test requires liger-kernel") require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender") require_math_latex = pytest.mark.skipif(not is_math_verify_available(), reason="test requires math_verify") require_mergekit = pytest.mark.skipif(not is_mergekit_available(), reason="test requires mergekit") @@ -45,8 +54,15 @@ require_sklearn = pytest.mark.skipif( not (is_sklearn_available() and is_joblib_available()), reason="test requires sklearn" ) +require_torch_accelerator = pytest.mark.skipif( + torch_device is None or torch_device == "cpu", reason="test requires accelerator" +) +require_torch_multi_accelerator = pytest.mark.skipif( + not is_torch_available() or backend_device_count(torch_device) <= 1, reason="test requires multiple accelerators" +) require_vision = pytest.mark.skipif(not is_vision_available(), reason="test requires vision") require_vllm = pytest.mark.skipif(not is_vllm_available(), reason="test requires vllm") +require_wandb = pytest.mark.skipif(not is_wandb_available(), reason="test requires wandb") require_no_wandb = pytest.mark.skipif(is_wandb_available(), reason="test requires no wandb") require_3_accelerators = pytest.mark.skipif( not (getattr(torch, torch_device, torch.cuda).device_count() >= 3), @@ -69,6 +85,23 @@ def is_bitsandbytes_multi_backend_available() -> bool: ) +def is_flash_attn_available(): + flash_attn_available = is_flash_attn_2_available() + kernels_available = is_kernels_available() + try: + from kernels import get_kernel + + get_kernel("kernels-community/flash-attn") + except Exception: + kernels_available = False + + return kernels_available or flash_attn_available + + +# Function ported from transformers.testing_utils +require_flash_attn = pytest.mark.skipif(not is_flash_attn_available(), reason="test requires Flash Attention") + + class RandomBinaryJudge(BaseBinaryJudge): """ Random binary judge, for testing purposes.