From 56052c8964eebc26a2cdd4337c7a6a1698fca648 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 4 Dec 2023 16:58:44 +0100 Subject: [PATCH 1/3] Add tests for 4bit LoftQ Add GPU tests for LoftQ with 4bit quantization. Notes Tests for 8bit quantization are already there but not run at the moment, see this comment: https://github.com/huggingface/peft/pull/1150#issuecomment-1838891499 In my testing, 8bit passes when using NFQuantizer, so if the original author is fine with using that, I can make the adjustment. --- tests/test_gpu_examples.py | 137 +++++++++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index 1af1919ad3..cf9ebb257c 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -39,7 +39,9 @@ from peft import ( AdaLoraConfig, + LoftQConfig, LoraConfig, + TaskType, get_peft_model, prepare_model_for_int8_training, prepare_model_for_kbit_training, @@ -933,3 +935,138 @@ def test_causal_lm_training_mutli_gpu(self): # assert loss is not None self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + +@require_torch_gpu +class LoftQTests(unittest.TestCase): + r""" + Tests for LoftQ + """ + + def setUp(self): + self.model_id = "hf-internal-testing/tiny-random-BloomForCausalLM" + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.inputs = self.tokenizer("All I want is", padding=True, return_tensors="pt").to("cuda") + + def get_errors(self, bits=4, loftq_iter=1): + # Helper function that returns the quantization errors (MAE and MSE) when comparing the quantized LoRA model + # to the base model, vs the LoftQ quantized model to the base model. We expect the LoftQ quantized model to + # have less error than the normal LoRA quantized model. Since we compare logits, the observed error is + # already somewhat dampened because of the softmax. 
+ model = AutoModelForCausalLM.from_pretrained(self.model_id).cuda().eval() + torch.manual_seed(0) + logits_base = model(**self.inputs).logits + # clean up + del model + gc.collect() + torch.cuda.empty_cache() + + # logits from the normal quantized LoRA model + lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM) + kwargs = {} + if bits == 4: + kwargs["load_in_4bit"] = True + elif bits == 8: + kwargs["load_in_8bit"] = True + else: + raise ValueError("bits must be 4 or 8") + + quantized_model = get_peft_model( + AutoModelForCausalLM.from_pretrained(self.model_id, device_map="auto", **kwargs).eval(), + lora_config, + ) + torch.manual_seed(0) + logits_quantized = quantized_model(**self.inputs).logits + del quantized_model + gc.collect() + torch.cuda.empty_cache() + + # logits from quantized LoRA model using LoftQ + loftq_config = LoftQConfig(loftq_bits=bits, loftq_iter=loftq_iter) + lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights="loftq", loftq_config=loftq_config) + loftq_model = get_peft_model(AutoModelForCausalLM.from_pretrained(self.model_id).cuda().eval(), lora_config) + torch.manual_seed(0) + logits_loftq = loftq_model(**self.inputs).logits + del loftq_model + gc.collect() + torch.cuda.empty_cache() + + mae_quantized = torch.abs(logits_base - logits_quantized).mean() + mse_quantized = torch.pow(logits_base - logits_quantized, 2).mean() + mae_loftq = torch.abs(logits_base - logits_loftq).mean() + mse_loftq = torch.pow(logits_base - logits_loftq, 2).mean() + return mae_quantized, mse_quantized, mae_loftq, mse_loftq + + def test_bloomz_loftq_4bit(self): + # In this test, we compare the logits of the base model, the quantized LoRA model, and the quantized model + # using LoftQ. When quantizing, we expect a certain level of error. However, we expect the LoftQ quantized + # model to have less error than the normal LoRA quantized model. Note that when using normal LoRA, the + # quantization error is simply the error from quantization without LoRA, as LoRA is a no-op before training. + # We still apply LoRA for the test for consistency. + + mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=4) + # first, sanity check that all errors are > 0.0 + self.assertTrue(mae_quantized > 0.0) + self.assertTrue(mse_quantized > 0.0) + self.assertTrue(mae_loftq > 0.0) + self.assertTrue(mse_loftq > 0.0) + + # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin + factor = 3 + self.assertTrue(mae_loftq < mae_quantized / factor) + self.assertTrue(mse_loftq < mse_quantized / factor) + + def test_bloomz_loftq_4bit_iter_5(self): + # Same test as the previous one but with 5 iterations. We should expect the error to be even smaller with more + # iterations, but in practice the difference is not that large, at least not for this small base model. 
+ mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=4, loftq_iter=5) + # first, sanity check that all errors are > 0.0 + self.assertTrue(mae_quantized > 0.0) + self.assertTrue(mse_quantized > 0.0) + self.assertTrue(mae_loftq > 0.0) + self.assertTrue(mse_loftq > 0.0) + + # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin + factor = 3 + self.assertTrue(mae_loftq < mae_quantized / factor) + self.assertTrue(mse_loftq < mse_quantized / factor) + + def test_bloomz_loftq_8bit(self): + # this currently does not work: + # https://github.com/huggingface/peft/pull/1150#issuecomment-1838891499 + if True: # TODO: remove as soon as the issue is fixed + return + + # Same test as test_bloomz_loftq_4bit but with 8 bits. + mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8) + + # first, sanity check that all errors are > 0.0 + self.assertTrue(mae_quantized > 0.0) + self.assertTrue(mse_quantized > 0.0) + self.assertTrue(mae_loftq > 0.0) + self.assertTrue(mse_loftq > 0.0) + + # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin + factor = 3 + self.assertTrue(mae_loftq < mae_quantized / factor) + self.assertTrue(mse_loftq < mse_quantized / factor) + + def test_bloomz_loftq_8bit_iter_5(self): + # this currently does not work: + # https://github.com/huggingface/peft/pull/1150#issuecomment-1838891499 + if True: # TODO: remove as soon as the issue is fixed + return + + # Same test as test_bloomz_loftq_4bit_iter_5 but with 8 bits. + mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8, loftq_iter=5) + + # first, sanity check that all errors are > 0.0 + self.assertTrue(mae_quantized > 0.0) + self.assertTrue(mse_quantized > 0.0) + self.assertTrue(mae_loftq > 0.0) + self.assertTrue(mse_loftq > 0.0) + + # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin + factor = 3 + self.assertTrue(mae_loftq < mae_quantized / factor) + self.assertTrue(mse_loftq < mse_quantized / factor) From 54e2a577c9e8bfe04ff6fd9441df39fcbda91057 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 5 Dec 2023 18:14:14 +0100 Subject: [PATCH 2/3] Apply suggestions from code review Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- tests/test_gpu_examples.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index cf9ebb257c..7d276252df 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -942,7 +942,7 @@ class LoftQTests(unittest.TestCase): r""" Tests for LoftQ """ - + error_factor = 3 def setUp(self): self.model_id = "hf-internal-testing/tiny-random-BloomForCausalLM" self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) @@ -1027,9 +1027,8 @@ def test_bloomz_loftq_4bit_iter_5(self): self.assertTrue(mse_loftq > 0.0) # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin - factor = 3 - self.assertTrue(mae_loftq < mae_quantized / factor) - self.assertTrue(mse_loftq < mse_quantized / factor) + self.assertTrue(mae_loftq < mae_quantized / self.error_factor) + self.assertTrue(mse_loftq < mse_quantized / self.error_factor) def test_bloomz_loftq_8bit(self): # this currently does not work: @@ -1047,9 +1046,8 @@ def test_bloomz_loftq_8bit(self): self.assertTrue(mse_loftq > 0.0) # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin 
- factor = 3 - self.assertTrue(mae_loftq < mae_quantized / factor) - self.assertTrue(mse_loftq < mse_quantized / factor) + self.assertTrue(mae_loftq < mae_quantized / self.error_factor) + self.assertTrue(mse_loftq < mse_quantized / self.error_factor) def test_bloomz_loftq_8bit_iter_5(self): # this currently does not work: @@ -1067,6 +1065,5 @@ def test_bloomz_loftq_8bit_iter_5(self): self.assertTrue(mse_loftq > 0.0) # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin - factor = 3 - self.assertTrue(mae_loftq < mae_quantized / factor) - self.assertTrue(mse_loftq < mse_quantized / factor) + self.assertTrue(mae_loftq < mae_quantized / self.error_factor) + self.assertTrue(mse_loftq < mse_quantized / self.error_factor) From f5bafb23f484dedcd51af2bbad86e7298f4e372e Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 5 Dec 2023 18:34:27 +0100 Subject: [PATCH 3/3] Make style --- tests/test_gpu_examples.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index 7d276252df..ad70ae45dd 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -942,8 +942,9 @@ class LoftQTests(unittest.TestCase): r""" Tests for LoftQ """ - error_factor = 3 + def setUp(self): + self.error_factor = 3 self.model_id = "hf-internal-testing/tiny-random-BloomForCausalLM" self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.inputs = self.tokenizer("All I want is", padding=True, return_tensors="pt").to("cuda")
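
The tests in this series exercise the LoftQ initialization path of PEFT's LoRA support. For orientation, a minimal usage sketch of that path follows; it simply mirrors the calls made in get_errors() above (same tiny Bloom checkpoint, 4-bit quantization, a single LoftQ iteration) and assumes a CUDA device is available.

    import torch
    from transformers import AutoModelForCausalLM
    from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model

    # Assumed values, taken from the tests above: the tiny test checkpoint and
    # 4-bit LoftQ with one alternating iteration (get_errors(bits=4, loftq_iter=1)).
    model_id = "hf-internal-testing/tiny-random-BloomForCausalLM"
    model = AutoModelForCausalLM.from_pretrained(model_id).cuda().eval()

    # LoftQ replaces the usual LoRA weight init: the base weights are quantized
    # and the LoRA matrices are initialized to compensate for the quantization
    # error, which is why the tests expect lower MAE/MSE than plain quantized LoRA.
    loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=1)
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        init_lora_weights="loftq",
        loftq_config=loftq_config,
    )
    peft_model = get_peft_model(model, lora_config)

The 8-bit variant (loftq_bits=8) is covered by the two skipped tests; as noted in the first commit message, it is not run at the moment and reportedly passes only when NFQuantizer is used.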