TST Make more tests work with MPS (huggingface#1463)
akx authored and BenjaminBossan committed Mar 14, 2024
1 parent ed9ef60 commit d0ef896
Showing 5 changed files with 60 additions and 51 deletions.
14 changes: 6 additions & 8 deletions src/peft/utils/other.py
@@ -62,18 +62,16 @@
 
 
 # Get current device name based on available devices
-def infer_device():
+def infer_device() -> str:
     if torch.cuda.is_available():
-        torch_device = "cuda"
+        return "cuda"
     elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-        torch_device = torch.device("mps")
+        return "mps"
     elif is_xpu_available():
-        torch_device = "xpu"
+        return "xpu"
     elif is_npu_available():
-        torch_device = "npu"
-    else:
-        torch_device = "cpu"
-    return torch_device
+        return "npu"
+    return "cpu"
 
 
 def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None):
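For context, `infer_device()` now returns a plain string for every backend (previously the MPS branch returned a `torch.device` object), which is what the new `-> str` annotation documents. A minimal usage sketch, not part of this commit, assuming only that PEFT is installed:

```python
# Minimal sketch (not from the commit): the string returned by infer_device()
# can be passed straight to .to(), regardless of which backend is available.
import torch

from peft.utils import infer_device

device = infer_device()  # "cuda", "mps", "xpu", "npu" or "cpu"
x = torch.randn(2, 3).to(device)
print(f"running on {device}: {x.device}")
```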
4 changes: 4 additions & 0 deletions tests/test_adaption_prompt.py
@@ -18,6 +18,7 @@
import unittest
from unittest import TestCase

import pytest
import torch
from torch.testing import assert_close

@@ -403,6 +404,9 @@ def test_use_cache(self) -> None:
assert_close(expected, actual, rtol=0, atol=0)

def test_bf16_inference(self) -> None:
if self.torch_device == "mps":
return pytest.skip("Skipping bf16 test on MPS")

"""Test that AdaptionPrompt works when Llama using a half-precision model."""
input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
original = LlamaForCausalLM.from_pretrained(
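The skip added above guards the bfloat16 test on MPS, which the commit treats as unsupported there. An equivalent, decorator-based formulation of the same guard (an illustration, not code from the repository):

```python
# Illustration only: the same MPS/bfloat16 guard written as a skipif marker
# instead of an inline pytest.skip() call inside the test body.
import pytest
import torch

on_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()


@pytest.mark.skipif(on_mps, reason="bfloat16 is not supported on MPS")
def test_bf16_tensor_roundtrip():
    x = torch.ones(2, 2, dtype=torch.bfloat16)
    assert x.to(torch.float32).sum().item() == 4.0
```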
51 changes: 27 additions & 24 deletions tests/test_auto.py
@@ -32,9 +32,12 @@
PeftModelForSequenceClassification,
PeftModelForTokenClassification,
)
from peft.utils import infer_device


class PeftAutoModelTester(unittest.TestCase):
dtype = torch.float16 if infer_device() == "mps" else torch.bfloat16

def test_peft_causal_lm(self):
model_id = "peft-internal-testing/tiny-OPTForCausalLM-lora"
model = AutoPeftModelForCausalLM.from_pretrained(model_id)
@@ -47,29 +50,29 @@ def test_peft_causal_lm(self):
assert isinstance(model, PeftModelForCausalLM)

# check if kwargs are passed correctly
model = AutoPeftModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = AutoPeftModelForCausalLM.from_pretrained(model_id, torch_dtype=self.dtype)
assert isinstance(model, PeftModelForCausalLM)
assert model.base_model.lm_head.weight.dtype == torch.bfloat16
assert model.base_model.lm_head.weight.dtype == self.dtype

adapter_name = "default"
is_trainable = False
# This should work
_ = AutoPeftModelForCausalLM.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=torch.bfloat16)
_ = AutoPeftModelForCausalLM.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=self.dtype)

def test_peft_causal_lm_extended_vocab(self):
model_id = "peft-internal-testing/tiny-random-OPTForCausalLM-extended-vocab"
model = AutoPeftModelForCausalLM.from_pretrained(model_id)
assert isinstance(model, PeftModelForCausalLM)

# check if kwargs are passed correctly
model = AutoPeftModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = AutoPeftModelForCausalLM.from_pretrained(model_id, torch_dtype=self.dtype)
assert isinstance(model, PeftModelForCausalLM)
assert model.base_model.lm_head.weight.dtype == torch.bfloat16
assert model.base_model.lm_head.weight.dtype == self.dtype

adapter_name = "default"
is_trainable = False
# This should work
_ = AutoPeftModelForCausalLM.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=torch.bfloat16)
_ = AutoPeftModelForCausalLM.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=self.dtype)

def test_peft_seq2seq_lm(self):
model_id = "peft-internal-testing/tiny_T5ForSeq2SeqLM-lora"
@@ -83,14 +86,14 @@ def test_peft_seq2seq_lm(self):
assert isinstance(model, PeftModelForSeq2SeqLM)

# check if kwargs are passed correctly
model = AutoPeftModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = AutoPeftModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=self.dtype)
assert isinstance(model, PeftModelForSeq2SeqLM)
assert model.base_model.lm_head.weight.dtype == torch.bfloat16
assert model.base_model.lm_head.weight.dtype == self.dtype

adapter_name = "default"
is_trainable = False
# This should work
_ = AutoPeftModelForSeq2SeqLM.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=torch.bfloat16)
_ = AutoPeftModelForSeq2SeqLM.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=self.dtype)

def test_peft_sequence_cls(self):
model_id = "peft-internal-testing/tiny_OPTForSequenceClassification-lora"
@@ -104,15 +107,15 @@ def test_peft_sequence_cls(self):
assert isinstance(model, PeftModelForSequenceClassification)

# check if kwargs are passed correctly
model = AutoPeftModelForSequenceClassification.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = AutoPeftModelForSequenceClassification.from_pretrained(model_id, torch_dtype=self.dtype)
assert isinstance(model, PeftModelForSequenceClassification)
assert model.score.original_module.weight.dtype == torch.bfloat16
assert model.score.original_module.weight.dtype == self.dtype

adapter_name = "default"
is_trainable = False
# This should work
_ = AutoPeftModelForSequenceClassification.from_pretrained(
model_id, adapter_name, is_trainable, torch_dtype=torch.bfloat16
model_id, adapter_name, is_trainable, torch_dtype=self.dtype
)

def test_peft_token_classification(self):
@@ -127,15 +130,15 @@ def test_peft_token_classification(self):
assert isinstance(model, PeftModelForTokenClassification)

# check if kwargs are passed correctly
model = AutoPeftModelForTokenClassification.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = AutoPeftModelForTokenClassification.from_pretrained(model_id, torch_dtype=self.dtype)
assert isinstance(model, PeftModelForTokenClassification)
assert model.base_model.classifier.original_module.weight.dtype == torch.bfloat16
assert model.base_model.classifier.original_module.weight.dtype == self.dtype

adapter_name = "default"
is_trainable = False
# This should work
_ = AutoPeftModelForTokenClassification.from_pretrained(
model_id, adapter_name, is_trainable, torch_dtype=torch.bfloat16
model_id, adapter_name, is_trainable, torch_dtype=self.dtype
)

def test_peft_question_answering(self):
@@ -150,15 +153,15 @@ def test_peft_question_answering(self):
assert isinstance(model, PeftModelForQuestionAnswering)

# check if kwargs are passed correctly
model = AutoPeftModelForQuestionAnswering.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = AutoPeftModelForQuestionAnswering.from_pretrained(model_id, torch_dtype=self.dtype)
assert isinstance(model, PeftModelForQuestionAnswering)
assert model.base_model.qa_outputs.original_module.weight.dtype == torch.bfloat16
assert model.base_model.qa_outputs.original_module.weight.dtype == self.dtype

adapter_name = "default"
is_trainable = False
# This should work
_ = AutoPeftModelForQuestionAnswering.from_pretrained(
model_id, adapter_name, is_trainable, torch_dtype=torch.bfloat16
model_id, adapter_name, is_trainable, torch_dtype=self.dtype
)

def test_peft_feature_extraction(self):
@@ -173,15 +176,15 @@ def test_peft_feature_extraction(self):
assert isinstance(model, PeftModelForFeatureExtraction)

# check if kwargs are passed correctly
model = AutoPeftModelForFeatureExtraction.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = AutoPeftModelForFeatureExtraction.from_pretrained(model_id, torch_dtype=self.dtype)
assert isinstance(model, PeftModelForFeatureExtraction)
assert model.base_model.model.decoder.embed_tokens.weight.dtype == torch.bfloat16
assert model.base_model.model.decoder.embed_tokens.weight.dtype == self.dtype

adapter_name = "default"
is_trainable = False
# This should work
_ = AutoPeftModelForFeatureExtraction.from_pretrained(
model_id, adapter_name, is_trainable, torch_dtype=torch.bfloat16
model_id, adapter_name, is_trainable, torch_dtype=self.dtype
)

def test_peft_whisper(self):
@@ -196,11 +199,11 @@ def test_peft_whisper(self):
assert isinstance(model, PeftModel)

# check if kwargs are passed correctly
model = AutoPeftModel.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = AutoPeftModel.from_pretrained(model_id, torch_dtype=self.dtype)
assert isinstance(model, PeftModel)
assert model.base_model.model.model.encoder.embed_positions.weight.dtype == torch.bfloat16
assert model.base_model.model.model.encoder.embed_positions.weight.dtype == self.dtype

adapter_name = "default"
is_trainable = False
# This should work
_ = AutoPeftModel.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=torch.bfloat16)
_ = AutoPeftModel.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=self.dtype)
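The class-level `dtype` attribute introduced above encodes a single rule: use float16 on MPS (where bfloat16 is unavailable) and bfloat16 everywhere else. A standalone sketch of that rule, not taken from the repository, checked against a freshly created module:

```python
# Sketch (assumptions: PEFT installed, bfloat16 unavailable on MPS): choose the
# half-precision dtype per backend and verify a module created with it keeps it.
import torch

from peft.utils import infer_device

dtype = torch.float16 if infer_device() == "mps" else torch.bfloat16
linear = torch.nn.Linear(4, 4, dtype=dtype)
assert linear.weight.dtype == dtype
print(f"{infer_device()}: using {dtype}")
```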
2 changes: 1 addition & 1 deletion tests/test_decoder_models.py
@@ -82,7 +82,7 @@ def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls
def test_prompt_tuning_text_prepare_for_training(self, test_name, model_id, config_cls, config_kwargs):
# Test that prompt tuning works with text init
if config_cls != PromptTuningConfig:
return
return pytest.skip(f"This test does not apply to {config_cls}")

config_kwargs = config_kwargs.copy()
config_kwargs["prompt_tuning_init"] = PromptTuningInit.TEXT
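Replacing a bare `return` with `pytest.skip(...)` changes how the non-applicable parametrizations are reported: they show up as skipped rather than as passing tests that exercised nothing (`pytest.skip()` raises, so the leading `return` never actually returns a value). A self-contained illustration, not from the repository:

```python
# Illustration: in a parametrized test, pytest.skip() marks the inapplicable
# case as "s" in the report, while a bare return would count it as passed.
import pytest


@pytest.mark.parametrize("value", [2, None])
def test_doubling(value):
    if value is None:
        pytest.skip("no value to double in this parametrization")
    assert value * 2 == 4
```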
40 changes: 22 additions & 18 deletions tests/testing_common.py
@@ -341,7 +341,7 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs, safe_serial
def _test_save_pretrained_selected_adapters(self, model_id, config_cls, config_kwargs, safe_serialization=True):
if issubclass(config_cls, (AdaLoraConfig, VeraConfig)):
# AdaLora does not support adding more than 1 adapter
return
return pytest.skip(f"Test not applicable for {config_cls}")

# ensure that the weights are randomly initialized
if issubclass(config_cls, LoraConfig):
@@ -446,9 +446,9 @@ def _test_from_pretrained_config_construction(self, model_id, config_cls, config
assert model_from_pretrained.peft_config["default"] is config

def _test_merge_layers_fp16(self, model_id, config_cls, config_kwargs):
if config_cls not in (LoraConfig, IA3Config, VeraConfig):
# Merge layers only supported for LoRA, IA³, VeRA
return
if config_cls not in (LoraConfig, IA3Config):
# Merge layers only supported for LoRA and IA³
return pytest.skip(f"Test not applicable for {config_cls}")

if ("gpt2" in model_id.lower()) and (config_cls != LoraConfig):
self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)")
@@ -528,7 +528,8 @@ def _test_merge_layers_nan(self, model_id, config_cls, config_kwargs):

def _test_merge_layers(self, model_id, config_cls, config_kwargs):
if issubclass(config_cls, PromptLearningConfig):
return pytest.skip(f"Test not applicable for {config_cls}")
return pytest.skip(f"Test not applicable for {config_cls}")

if ("gpt2" in model_id.lower()) and (config_cls != LoraConfig):
self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)")

@@ -702,8 +703,11 @@ def _test_generate_pos_args(self, model_id, config_cls, config_kwargs, raises_er
_ = model.generate(inputs["input_ids"])

def _test_generate_half_prec(self, model_id, config_cls, config_kwargs):
if config_cls not in (IA3Config, LoraConfig, VeraConfig, PrefixTuningConfig):
return
if config_cls not in (IA3Config, LoraConfig, PrefixTuningConfig):
return pytest.skip(f"Test not applicable for {config_cls}")

if self.torch_device == "mps": # BFloat16 is not supported on MPS
return pytest.skip("BFloat16 is not supported on MPS")

model = self.transformers_class.from_pretrained(model_id, torch_dtype=torch.bfloat16)
config = config_cls(
Expand All @@ -721,7 +725,7 @@ def _test_generate_half_prec(self, model_id, config_cls, config_kwargs):

def _test_prefix_tuning_half_prec_conversion(self, model_id, config_cls, config_kwargs):
if config_cls not in (PrefixTuningConfig,):
return
return pytest.skip(f"Test not applicable for {config_cls}")

config = config_cls(
base_model_name_or_path=model_id,
@@ -802,7 +806,7 @@ def _test_inference_safetensors(self, model_id, config_cls, config_kwargs):

def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
if config_cls not in (LoraConfig,):
return
return pytest.skip(f"Test not applicable for {config_cls}")

config = config_cls(
base_model_name_or_path=model_id,
@@ -865,7 +869,7 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwa
model = self.transformers_class.from_pretrained(model_id)

if not getattr(model, "supports_gradient_checkpointing", False):
return
return pytest.skip(f"Model {model_id} does not support gradient checkpointing")

model.gradient_checkpointing_enable()

@@ -898,7 +902,7 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwa

def _test_peft_model_device_map(self, model_id, config_cls, config_kwargs):
if config_cls not in (LoraConfig,):
return
return pytest.skip(f"Test not applicable for {config_cls}")

config = config_cls(
base_model_name_or_path=model_id,
Expand All @@ -920,7 +924,7 @@ def _test_peft_model_device_map(self, model_id, config_cls, config_kwargs):

def _test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwargs):
if not issubclass(config_cls, PromptLearningConfig):
return
return pytest.skip(f"Test not applicable for {config_cls}")

model = self.transformers_class.from_pretrained(model_id)
config = config_cls(
@@ -950,7 +954,7 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
**config_kwargs,
)
if config.peft_type not in supported_peft_types:
return
return pytest.skip(f"Test not applicable for {config.peft_type}")

model = self.transformers_class.from_pretrained(model_id)
adapter_to_delete = "delete_me"
@@ -988,7 +992,7 @@ def _test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs):
**config_kwargs,
)
if config.peft_type not in supported_peft_types:
return
return pytest.skip(f"Test not applicable for {config.peft_type}")

model = self.transformers_class.from_pretrained(model_id)
adapter_to_delete = "delete_me"
@@ -1045,16 +1049,16 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kwargs):
if issubclass(config_cls, (AdaLoraConfig, VeraConfig)):
# AdaLora does not support adding more than 1 adapter
return
return pytest.skip(f"Test not applicable for {config_cls}")

adapter_list = ["adapter1", "adapter_2", "adapter_3"]
weight_list = [0.5, 1.5, 1.5]
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)
if not isinstance(config, (LoraConfig)):
return
if not isinstance(config, LoraConfig):
return pytest.skip(f"Test not applicable for {config}")

model = self.transformers_class.from_pretrained(model_id)
model = get_peft_model(model, config, adapter_list[0])
@@ -1314,7 +1318,7 @@ def _test_adding_multiple_adapters_with_bias_raises(self, model_id, config_cls,
# raised. Also, the peft model should not be left in a half-initialized state.

if not issubclass(config_cls, (LoraConfig, AdaLoraConfig)):
return
return pytest.skip(f"Test not applicable for {config_cls}")

config_kwargs = config_kwargs.copy()
config_kwargs["bias"] = "all"
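The gradient-checkpointing change applies the same idea to capability checks: when the loaded model lacks the feature, report a skip instead of silently passing. A hedged sketch with an invented `DummyModel` stand-in (not a real transformers class):

```python
# Sketch only: DummyModel is a made-up stand-in for a transformers model that
# does not support gradient checkpointing; the guard mirrors the one above.
import pytest


class DummyModel:
    supports_gradient_checkpointing = False


def test_gradient_checkpointing_guard():
    model = DummyModel()
    if not getattr(model, "supports_gradient_checkpointing", False):
        pytest.skip(f"{type(model).__name__} does not support gradient checkpointing")
    model.gradient_checkpointing_enable()  # not reached for DummyModel
```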
