
Commit 3708793

TST Extend LoftQ tests to check CPU initialization (#1274)
Tests to complement PR #1256
1 parent 46a84bd commit 3708793
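
The pattern behind the change: each LoftQ test is parameterized over a device string, and every unconditional `.cuda()` call is replaced with a conditional move so the same assertions can also run on CPU. A minimal, self-contained sketch of that pattern follows; the test class and the tensor check are illustrative only and are not part of the actual diff:

import unittest

import torch
from parameterized import parameterized


class DeviceParamExample(unittest.TestCase):
    # Illustrative only: parameterized.expand generates one test case per device,
    # and the tensor is moved to CUDA only when that device is requested.
    @parameterized.expand(["cuda", "cpu"])
    def test_sum_matches_on_device(self, device):
        if device == "cuda" and not torch.cuda.is_available():
            self.skipTest("CUDA not available")
        x = torch.ones(4)
        if device == "cuda":
            x = x.to("cuda")
        self.assertEqual(x.sum().item(), 4.0)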

1 file changed: +42 -16 lines changed

tests/test_gpu_examples.py (+42 -16)
@@ -24,6 +24,7 @@
 from accelerate.test_utils.testing import run_command
 from accelerate.utils import patch_environment
 from datasets import Audio, DatasetDict, load_dataset
+from parameterized import parameterized
 from transformers import (
     AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
@@ -950,16 +951,31 @@ def setUp(self):
         self.error_factor = 3
         self.model_id = "hf-internal-testing/tiny-random-BloomForCausalLM"
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
-        self.inputs = self.tokenizer("All I want is", padding=True, return_tensors="pt").to("cuda")
 
-    def get_errors(self, bits=4, loftq_iter=1):
+    def get_input(self, device):
+        inputs = self.tokenizer("All I want is", padding=True, return_tensors="pt")
+        if device == "cuda":
+            inputs = inputs.to("cuda")
+        return inputs
+
+    def get_base_model(self, model_id, device, **kwargs):
+        model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).eval()
+        if device == "cuda":
+            model = model.to("cuda")
+        return model
+
+    def get_errors(self, bits=4, loftq_iter=1, device="cuda"):
         # Helper function that returns the quantization errors (MAE and MSE) when comparing the quantized LoRA model
         # to the base model, vs the LoftQ quantized model to the base model. We expect the LoftQ quantized model to
         # have less error than the normal LoRA quantized model. Since we compare logits, the observed error is
         # already somewhat dampened because of the softmax.
-        model = AutoModelForCausalLM.from_pretrained(self.model_id).cuda().eval()
+        model = self.get_base_model(self.model_id, device)
+        if device == "cuda":
+            model = model.to("cuda")
+
         torch.manual_seed(0)
-        logits_base = model(**self.inputs).logits
+        inputs = self.get_input(device)
+        logits_base = model(**inputs).logits
         # clean up
         del model
         gc.collect()
@@ -976,21 +992,27 @@ def get_errors(self, bits=4, loftq_iter=1):
             raise ValueError("bits must be 4 or 8")
 
         quantized_model = get_peft_model(
-            AutoModelForCausalLM.from_pretrained(self.model_id, device_map="auto", **kwargs).eval(),
+            self.get_base_model(self.model_id, device=None, **kwargs),
             lora_config,
         )
         torch.manual_seed(0)
-        logits_quantized = quantized_model(**self.inputs).logits
+        logits_quantized = quantized_model(**inputs).logits
         del quantized_model
         gc.collect()
         torch.cuda.empty_cache()
 
         # logits from quantized LoRA model using LoftQ
         loftq_config = LoftQConfig(loftq_bits=bits, loftq_iter=loftq_iter)
         lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights="loftq", loftq_config=loftq_config)
-        loftq_model = get_peft_model(AutoModelForCausalLM.from_pretrained(self.model_id).cuda().eval(), lora_config)
+        model = self.get_base_model(self.model_id, device)
+        if device == "cuda":
+            model = model.to("cuda")
+        loftq_model = get_peft_model(model, lora_config)
+        if device == "cuda":
+            loftq_model = loftq_model.to("cuda")
+
         torch.manual_seed(0)
-        logits_loftq = loftq_model(**self.inputs).logits
+        logits_loftq = loftq_model(**inputs).logits
         del loftq_model
         gc.collect()
         torch.cuda.empty_cache()
@@ -1001,14 +1023,15 @@ def get_errors(self, bits=4, loftq_iter=1):
         mse_loftq = torch.pow(logits_base - logits_loftq, 2).mean()
         return mae_quantized, mse_quantized, mae_loftq, mse_loftq
 
-    def test_bloomz_loftq_4bit(self):
+    @parameterized.expand(["cuda", "cpu"])
+    def test_bloomz_loftq_4bit(self, device):
         # In this test, we compare the logits of the base model, the quantized LoRA model, and the quantized model
         # using LoftQ. When quantizing, we expect a certain level of error. However, we expect the LoftQ quantized
         # model to have less error than the normal LoRA quantized model. Note that when using normal LoRA, the
         # quantization error is simply the error from quantization without LoRA, as LoRA is a no-op before training.
         # We still apply LoRA for the test for consistency.
 
-        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=4)
+        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=4, device=device)
         # first, sanity check that all errors are > 0.0
         self.assertTrue(mae_quantized > 0.0)
         self.assertTrue(mse_quantized > 0.0)
@@ -1020,10 +1043,11 @@ def test_bloomz_loftq_4bit(self):
         self.assertTrue(mae_loftq < mae_quantized / factor)
         self.assertTrue(mse_loftq < mse_quantized / factor)
 
-    def test_bloomz_loftq_4bit_iter_5(self):
+    @parameterized.expand(["cuda", "cpu"])
+    def test_bloomz_loftq_4bit_iter_5(self, device):
         # Same test as the previous one but with 5 iterations. We should expect the error to be even smaller with more
         # iterations, but in practice the difference is not that large, at least not for this small base model.
-        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=4, loftq_iter=5)
+        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=4, loftq_iter=5, device=device)
         # first, sanity check that all errors are > 0.0
         self.assertTrue(mae_quantized > 0.0)
         self.assertTrue(mse_quantized > 0.0)
@@ -1034,14 +1058,15 @@ def test_bloomz_loftq_4bit_iter_5(self):
         self.assertTrue(mae_loftq < mae_quantized / self.error_factor)
         self.assertTrue(mse_loftq < mse_quantized / self.error_factor)
 
-    def test_bloomz_loftq_8bit(self):
+    @parameterized.expand(["cuda", "cpu"])
+    def test_bloomz_loftq_8bit(self, device):
         # this currently does not work:
         # https://github.com/huggingface/peft/pull/1150#issuecomment-1838891499
         if True:  # TODO: remove as soon as the issue is fixed
             return
 
         # Same test as test_bloomz_loftq_4bit but with 8 bits.
-        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8)
+        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8, device=device)
 
         # first, sanity check that all errors are > 0.0
         self.assertTrue(mae_quantized > 0.0)
@@ -1053,14 +1078,15 @@ def test_bloomz_loftq_8bit(self):
         self.assertTrue(mae_loftq < mae_quantized / self.error_factor)
         self.assertTrue(mse_loftq < mse_quantized / self.error_factor)
 
-    def test_bloomz_loftq_8bit_iter_5(self):
+    @parameterized.expand(["cuda", "cpu"])
+    def test_bloomz_loftq_8bit_iter_5(self, device):
         # this currently does not work:
         # https://github.com/huggingface/peft/pull/1150#issuecomment-1838891499
         if True:  # TODO: remove as soon as the issue is fixed
             return
 
         # Same test as test_bloomz_loftq_4bit_iter_5 but with 8 bits.
-        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8, loftq_iter=5)
+        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8, loftq_iter=5, device=device)
 
         # first, sanity check that all errors are > 0.0
         self.assertTrue(mae_quantized > 0.0)
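
For reference, the comparison driving these tests reduces to mean absolute error and mean squared error between logit tensors. A minimal sketch of that computation; the helper name logit_errors is an assumption for illustration and does not appear in the diff:

import torch


def logit_errors(logits_a: torch.Tensor, logits_b: torch.Tensor):
    # Mirrors the metrics returned by get_errors: the diff computes MSE as
    # torch.pow(a - b, 2).mean(); MAE is the corresponding mean absolute error.
    mae = (logits_a - logits_b).abs().mean()
    mse = torch.pow(logits_a - logits_b, 2).mean()
    return mae, mse

The tests then assert that LoftQ initialization shrinks both errors relative to plain quantized LoRA by at least self.error_factor (set to 3 in setUp), now on both CUDA and CPU.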
