TST: Add tests for 4bit LoftQ #1208

Merged 4 commits on Dec 11, 2023
135 changes: 135 additions & 0 deletions tests/test_gpu_examples.py
@@ -41,7 +41,9 @@

from peft import (
    AdaLoraConfig,
    LoftQConfig,
    LoraConfig,
    TaskType,
    get_peft_model,
    prepare_model_for_int8_training,
    prepare_model_for_kbit_training,
@@ -937,6 +939,139 @@ def test_causal_lm_training_mutli_gpu(self):
        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])


@require_torch_gpu
class LoftQTests(unittest.TestCase):
    r"""
    Tests for LoftQ
    """

    def setUp(self):
        self.error_factor = 3
        self.model_id = "hf-internal-testing/tiny-random-BloomForCausalLM"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.inputs = self.tokenizer("All I want is", padding=True, return_tensors="pt").to("cuda")

    def get_errors(self, bits=4, loftq_iter=1):
        # Helper function that returns the quantization errors (MAE and MSE) when comparing the quantized LoRA model
        # to the base model, vs the LoftQ quantized model to the base model. We expect the LoftQ quantized model to
        # have less error than the normal LoRA quantized model. Since we compare logits, the observed error is
        # already somewhat dampened because of the softmax.
        model = AutoModelForCausalLM.from_pretrained(self.model_id).cuda().eval()
        torch.manual_seed(0)
        logits_base = model(**self.inputs).logits
        # clean up
        del model
        gc.collect()
        torch.cuda.empty_cache()

        # logits from the normal quantized LoRA model
        lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM)
        kwargs = {}
        if bits == 4:
            kwargs["load_in_4bit"] = True
        elif bits == 8:
            kwargs["load_in_8bit"] = True
        else:
            raise ValueError("bits must be 4 or 8")

        quantized_model = get_peft_model(
            AutoModelForCausalLM.from_pretrained(self.model_id, device_map="auto", **kwargs).eval(),
            lora_config,
        )
        torch.manual_seed(0)
        logits_quantized = quantized_model(**self.inputs).logits
        del quantized_model
        gc.collect()
        torch.cuda.empty_cache()

        # logits from quantized LoRA model using LoftQ
        loftq_config = LoftQConfig(loftq_bits=bits, loftq_iter=loftq_iter)
        lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights="loftq", loftq_config=loftq_config)
        loftq_model = get_peft_model(AutoModelForCausalLM.from_pretrained(self.model_id).cuda().eval(), lora_config)
        torch.manual_seed(0)
        logits_loftq = loftq_model(**self.inputs).logits
        del loftq_model
        gc.collect()
        torch.cuda.empty_cache()

        mae_quantized = torch.abs(logits_base - logits_quantized).mean()
        mse_quantized = torch.pow(logits_base - logits_quantized, 2).mean()
        mae_loftq = torch.abs(logits_base - logits_loftq).mean()
        mse_loftq = torch.pow(logits_base - logits_loftq, 2).mean()
        return mae_quantized, mse_quantized, mae_loftq, mse_loftq

    def test_bloomz_loftq_4bit(self):
        # In this test, we compare the logits of the base model, the quantized LoRA model, and the quantized model
        # using LoftQ. When quantizing, we expect a certain level of error. However, we expect the LoftQ quantized
        # model to have less error than the normal LoRA quantized model. Note that when using normal LoRA, the
        # quantization error is simply the error from quantization without LoRA, as LoRA is a no-op before training.
        # We still apply LoRA for the test for consistency.
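        # (With PEFT's default LoRA initialization, lora_B starts at zero, so the adapter update B @ A is zero and
        # the LoRA-wrapped quantized model's logits match the plain quantized model's logits before any training.)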

        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=4)
        # first, sanity check that all errors are > 0.0
        self.assertTrue(mae_quantized > 0.0)
        self.assertTrue(mse_quantized > 0.0)
        self.assertTrue(mae_loftq > 0.0)
        self.assertTrue(mse_loftq > 0.0)

        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
        self.assertTrue(mae_loftq < mae_quantized / self.error_factor)
        self.assertTrue(mse_loftq < mse_quantized / self.error_factor)

    def test_bloomz_loftq_4bit_iter_5(self):
        # Same test as the previous one but with 5 iterations. We should expect the error to be even smaller with more
        # iterations, but in practice the difference is not that large, at least not for this small base model.
        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=4, loftq_iter=5)
        # first, sanity check that all errors are > 0.0
        self.assertTrue(mae_quantized > 0.0)
        self.assertTrue(mse_quantized > 0.0)
        self.assertTrue(mae_loftq > 0.0)
        self.assertTrue(mse_loftq > 0.0)

        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
        self.assertTrue(mae_loftq < mae_quantized / self.error_factor)
        self.assertTrue(mse_loftq < mse_quantized / self.error_factor)

    def test_bloomz_loftq_8bit(self):
        # this currently does not work:
        # https://github.com/huggingface/peft/pull/1150#issuecomment-1838891499
        if True:  # TODO: remove as soon as the issue is fixed
            return

        # Same test as test_bloomz_loftq_4bit but with 8 bits.
        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8)

        # first, sanity check that all errors are > 0.0
        self.assertTrue(mae_quantized > 0.0)
        self.assertTrue(mse_quantized > 0.0)
        self.assertTrue(mae_loftq > 0.0)
        self.assertTrue(mse_loftq > 0.0)

        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
        self.assertTrue(mae_loftq < mae_quantized / self.error_factor)
        self.assertTrue(mse_loftq < mse_quantized / self.error_factor)

    def test_bloomz_loftq_8bit_iter_5(self):
        # this currently does not work:
        # https://github.com/huggingface/peft/pull/1150#issuecomment-1838891499
        if True:  # TODO: remove as soon as the issue is fixed
            return

        # Same test as test_bloomz_loftq_4bit_iter_5 but with 8 bits.
        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8, loftq_iter=5)

        # first, sanity check that all errors are > 0.0
        self.assertTrue(mae_quantized > 0.0)
        self.assertTrue(mse_quantized > 0.0)
        self.assertTrue(mae_loftq > 0.0)
        self.assertTrue(mse_loftq > 0.0)

        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
        self.assertTrue(mae_loftq < mae_quantized / self.error_factor)
        self.assertTrue(mse_loftq < mse_quantized / self.error_factor)


@require_bitsandbytes
@require_torch_gpu
class MultiprocessTester(unittest.TestCase):
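For reference, the LoftQ initialization path exercised by these tests can be distilled into a short standalone sketch. This is not part of the PR; it simply restates the calls made in LoftQTests.get_errors, assuming the same tiny test checkpoint and the PEFT/transformers APIs imported in the diff.

# Minimal sketch (not from the PR): 4-bit LoftQ initialization as used in LoftQTests.get_errors.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model

model_id = "hf-internal-testing/tiny-random-BloomForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
base_model = AutoModelForCausalLM.from_pretrained(model_id).cuda().eval()

# LoftQ chooses the initial LoRA A/B matrices so that (quantized weight + B @ A) approximates the
# original full-precision weight; more iterations refine the approximation.
loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=1)
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights="loftq", loftq_config=loftq_config)
loftq_model = get_peft_model(base_model, lora_config)

inputs = tokenizer("All I want is", return_tensors="pt").to("cuda")
with torch.no_grad():
    logits = loftq_model(**inputs).logits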