Skip to content

Commit

Permalink
PeftAutoModelTester: use float16 on MPS
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Feb 16, 2024
1 parent b5aed90 commit 6bd2fed
Showing 1 changed file with 27 additions and 24 deletions.
51 changes: 27 additions & 24 deletions tests/test_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@
PeftModelForSequenceClassification,
PeftModelForTokenClassification,
)
from peft.utils import infer_device


class PeftAutoModelTester(unittest.TestCase):
dtype = torch.float16 if infer_device() == "mps" else torch.bfloat16

def test_peft_causal_lm(self):
model_id = "peft-internal-testing/tiny-OPTForCausalLM-lora"
model = AutoPeftModelForCausalLM.from_pretrained(model_id)
Expand All @@ -47,29 +50,29 @@ def test_peft_causal_lm(self):
assert isinstance(model, PeftModelForCausalLM)

# check if kwargs are passed correctly
model = AutoPeftModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = AutoPeftModelForCausalLM.from_pretrained(model_id, torch_dtype=self.dtype)
assert isinstance(model, PeftModelForCausalLM)
assert model.base_model.lm_head.weight.dtype == torch.bfloat16
assert model.base_model.lm_head.weight.dtype == self.dtype

adapter_name = "default"
is_trainable = False
# This should work
_ = AutoPeftModelForCausalLM.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=torch.bfloat16)
_ = AutoPeftModelForCausalLM.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=self.dtype)

def test_peft_causal_lm_extended_vocab(self):
model_id = "peft-internal-testing/tiny-random-OPTForCausalLM-extended-vocab"
model = AutoPeftModelForCausalLM.from_pretrained(model_id)
assert isinstance(model, PeftModelForCausalLM)

# check if kwargs are passed correctly
model = AutoPeftModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = AutoPeftModelForCausalLM.from_pretrained(model_id, torch_dtype=self.dtype)
assert isinstance(model, PeftModelForCausalLM)
assert model.base_model.lm_head.weight.dtype == torch.bfloat16
assert model.base_model.lm_head.weight.dtype == self.dtype

adapter_name = "default"
is_trainable = False
# This should work
_ = AutoPeftModelForCausalLM.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=torch.bfloat16)
_ = AutoPeftModelForCausalLM.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=self.dtype)

def test_peft_seq2seq_lm(self):
model_id = "peft-internal-testing/tiny_T5ForSeq2SeqLM-lora"
Expand All @@ -83,14 +86,14 @@ def test_peft_seq2seq_lm(self):
assert isinstance(model, PeftModelForSeq2SeqLM)

# check if kwargs are passed correctly
model = AutoPeftModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = AutoPeftModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=self.dtype)
assert isinstance(model, PeftModelForSeq2SeqLM)
assert model.base_model.lm_head.weight.dtype == torch.bfloat16
assert model.base_model.lm_head.weight.dtype == self.dtype

adapter_name = "default"
is_trainable = False
# This should work
_ = AutoPeftModelForSeq2SeqLM.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=torch.bfloat16)
_ = AutoPeftModelForSeq2SeqLM.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=self.dtype)

def test_peft_sequence_cls(self):
model_id = "peft-internal-testing/tiny_OPTForSequenceClassification-lora"
Expand All @@ -104,15 +107,15 @@ def test_peft_sequence_cls(self):
assert isinstance(model, PeftModelForSequenceClassification)

# check if kwargs are passed correctly
model = AutoPeftModelForSequenceClassification.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = AutoPeftModelForSequenceClassification.from_pretrained(model_id, torch_dtype=self.dtype)
assert isinstance(model, PeftModelForSequenceClassification)
assert model.score.original_module.weight.dtype == torch.bfloat16
assert model.score.original_module.weight.dtype == self.dtype

adapter_name = "default"
is_trainable = False
# This should work
_ = AutoPeftModelForSequenceClassification.from_pretrained(
model_id, adapter_name, is_trainable, torch_dtype=torch.bfloat16
model_id, adapter_name, is_trainable, torch_dtype=self.dtype
)

def test_peft_token_classification(self):
Expand All @@ -127,15 +130,15 @@ def test_peft_token_classification(self):
assert isinstance(model, PeftModelForTokenClassification)

# check if kwargs are passed correctly
model = AutoPeftModelForTokenClassification.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = AutoPeftModelForTokenClassification.from_pretrained(model_id, torch_dtype=self.dtype)
assert isinstance(model, PeftModelForTokenClassification)
assert model.base_model.classifier.original_module.weight.dtype == torch.bfloat16
assert model.base_model.classifier.original_module.weight.dtype == self.dtype

adapter_name = "default"
is_trainable = False
# This should work
_ = AutoPeftModelForTokenClassification.from_pretrained(
model_id, adapter_name, is_trainable, torch_dtype=torch.bfloat16
model_id, adapter_name, is_trainable, torch_dtype=self.dtype
)

def test_peft_question_answering(self):
Expand All @@ -150,15 +153,15 @@ def test_peft_question_answering(self):
assert isinstance(model, PeftModelForQuestionAnswering)

# check if kwargs are passed correctly
model = AutoPeftModelForQuestionAnswering.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = AutoPeftModelForQuestionAnswering.from_pretrained(model_id, torch_dtype=self.dtype)
assert isinstance(model, PeftModelForQuestionAnswering)
assert model.base_model.qa_outputs.original_module.weight.dtype == torch.bfloat16
assert model.base_model.qa_outputs.original_module.weight.dtype == self.dtype

adapter_name = "default"
is_trainable = False
# This should work
_ = AutoPeftModelForQuestionAnswering.from_pretrained(
model_id, adapter_name, is_trainable, torch_dtype=torch.bfloat16
model_id, adapter_name, is_trainable, torch_dtype=self.dtype
)

def test_peft_feature_extraction(self):
Expand All @@ -173,15 +176,15 @@ def test_peft_feature_extraction(self):
assert isinstance(model, PeftModelForFeatureExtraction)

# check if kwargs are passed correctly
model = AutoPeftModelForFeatureExtraction.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = AutoPeftModelForFeatureExtraction.from_pretrained(model_id, torch_dtype=self.dtype)
assert isinstance(model, PeftModelForFeatureExtraction)
assert model.base_model.model.decoder.embed_tokens.weight.dtype == torch.bfloat16
assert model.base_model.model.decoder.embed_tokens.weight.dtype == self.dtype

adapter_name = "default"
is_trainable = False
# This should work
_ = AutoPeftModelForFeatureExtraction.from_pretrained(
model_id, adapter_name, is_trainable, torch_dtype=torch.bfloat16
model_id, adapter_name, is_trainable, torch_dtype=self.dtype
)

def test_peft_whisper(self):
Expand All @@ -196,11 +199,11 @@ def test_peft_whisper(self):
assert isinstance(model, PeftModel)

# check if kwargs are passed correctly
model = AutoPeftModel.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = AutoPeftModel.from_pretrained(model_id, torch_dtype=self.dtype)
assert isinstance(model, PeftModel)
assert model.base_model.model.model.encoder.embed_positions.weight.dtype == torch.bfloat16
assert model.base_model.model.model.encoder.embed_positions.weight.dtype == self.dtype

adapter_name = "default"
is_trainable = False
# This should work
_ = AutoPeftModel.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=torch.bfloat16)
_ = AutoPeftModel.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=self.dtype)

0 comments on commit 6bd2fed

Please sign in to comment.