From f5d0f040403076f3e3a43c9f27b9ca5a33e723c5 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Tue, 28 Mar 2023 12:54:40 +0000
Subject: [PATCH 1/7] add more tests

---
 tests/test_peft_model.py | 7 ++++++-
 tests/testing_common.py | 11 ++++++-----
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/tests/test_peft_model.py b/tests/test_peft_model.py
index 2ca4895dd6..3fa4eeb3d1 100644
--- a/tests/test_peft_model.py
+++ b/tests/test_peft_model.py
@@ -32,7 +32,12 @@

 # This has to be in the order: model_id, lora_kwargs, prefix_tuning_kwargs, prompt_encoder_kwargs, prompt_tuning_kwargs
 PEFT_MODELS_TO_TEST = [
-    ("hf-internal-testing/tiny-random-OPTForCausalLM", {"target_modules": ["q_proj", "v_proj"]}, {}, {}, {}),
+    ("hf-internal-testing/tiny-random-OPTForCausalLM", {}, {}, {}, {}),
+    ("HuggingFaceM4/tiny-random-LlamaForCausalLM", {}, {}, {}, {}),
+    ("hf-internal-testing/tiny-random-GPTNeoXForCausalLM", {}, {}, {}, {}),
+    ("hf-internal-testing/tiny-random-GPT2LMHeadModel", {}, {}, {}, {}),
+    ("hf-internal-testing/tiny-random-BloomForCausalLM", {}, {}, {}, {}),
+    ("hf-internal-testing/tiny-random-gpt_neo", {}, {}, {}, {}),
 ]


diff --git a/tests/testing_common.py b/tests/testing_common.py
index dfdf1d8be1..96c0fdb476 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -79,23 +79,24 @@ def get_grid_parameters(self, model_list):
         for model_tuple in model_list:
             model_id, lora_kwargs, prefix_tuning_kwargs, prompt_encoder_kwargs, prompt_tuning_kwargs = model_tuple
             for key, value in self.items():
+                peft_method = value[1].copy()
                 if key == "lora":
                     # update value[1] if necessary
                     if lora_kwargs is not None:
-                        value[1].update(lora_kwargs)
+                        peft_method.update(lora_kwargs)
                 elif key == "prefix_tuning":
                     # update value[1] if necessary
                     if prefix_tuning_kwargs is not None:
-                        value[1].update(prefix_tuning_kwargs)
+                        peft_method.update(prefix_tuning_kwargs)
                 elif key == "prompt_encoder":
                     # update value[1] if necessary
                     if prompt_encoder_kwargs is not None:
-                        value[1].update(prompt_encoder_kwargs)
+                        peft_method.update(prompt_encoder_kwargs)
                 else:
                     # update value[1] if necessary
                     if prompt_tuning_kwargs is not None:
-                        value[1].update(prompt_tuning_kwargs)
-                grid_parameters.append((f"test_{model_id}_{key}", model_id, value[0], value[1]))
+                        peft_method.update(prompt_tuning_kwargs)
+                grid_parameters.append((f"test_{model_id}_{key}", model_id, value[0], peft_method))

         return grid_parameters

From d1e5a4f0852080682f09751a697f24e8d8bc8cf3 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Tue, 28 Mar 2023 13:10:31 +0000
Subject: [PATCH 2/7] fix

---
 tests/test_peft_model.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/tests/test_peft_model.py b/tests/test_peft_model.py
index 3fa4eeb3d1..8d3bc5b042 100644
--- a/tests/test_peft_model.py
+++ b/tests/test_peft_model.py
@@ -31,13 +31,14 @@


 # This has to be in the order: model_id, lora_kwargs, prefix_tuning_kwargs, prompt_encoder_kwargs, prompt_tuning_kwargs
-PEFT_MODELS_TO_TEST = [
+PEFT_DECODER_MODELS_TO_TEST = [
     ("hf-internal-testing/tiny-random-OPTForCausalLM", {}, {}, {}, {}),
     ("HuggingFaceM4/tiny-random-LlamaForCausalLM", {}, {}, {}, {}),
     ("hf-internal-testing/tiny-random-GPTNeoXForCausalLM", {}, {}, {}, {}),
     ("hf-internal-testing/tiny-random-GPT2LMHeadModel", {}, {}, {}, {}),
     ("hf-internal-testing/tiny-random-BloomForCausalLM", {}, {}, {}, {}),
     ("hf-internal-testing/tiny-random-gpt_neo", {}, {}, {}, {}),
+    ("hf-internal-testing/tiny-random-GPTJForCausalLM", {}, {}, {}, {}),
 ]


@@ -53,7 +54,7 @@ class PeftModelTester(unittest.TestCase, PeftTestMixin):
     We use parameterized.expand for debugging purposes to test each model individually.
     """

-    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_MODELS_TO_TEST))
+    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_DECODER_MODELS_TO_TEST))
     def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs):
         self._test_model_attr(model_id, config_cls, config_kwargs)

@@ -110,7 +111,7 @@ def make_inputs_require_grad(module, input, output):

         self.assertTrue(dummy_output.requires_grad)

-    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_MODELS_TO_TEST))
+    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_DECODER_MODELS_TO_TEST))
     def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls, config_kwargs):
         self._test_prepare_for_training(model_id, config_cls, config_kwargs)

@@ -156,6 +157,6 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs):
         # check if `config.json` is not present
         self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))

-    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_MODELS_TO_TEST))
+    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_DECODER_MODELS_TO_TEST))
     def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs):
         self._test_save_pretrained(model_id, config_cls, config_kwargs)

From d634bd52143961c0e4adfee745293e7e48f3b9da Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Tue, 28 Mar 2023 13:20:00 +0000
Subject: [PATCH 3/7] add generate tests

---
 tests/test_peft_model.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/tests/test_peft_model.py b/tests/test_peft_model.py
index 8d3bc5b042..ff5ec17c95 100644
--- a/tests/test_peft_model.py
+++ b/tests/test_peft_model.py
@@ -160,3 +160,26 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs):
     @parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_DECODER_MODELS_TO_TEST))
     def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs):
         self._test_save_pretrained(model_id, config_cls, config_kwargs)
+
+    def _test_generate(self, model_id, config_cls, config_kwargs):
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        model = get_peft_model(model, config)
+        model = model.to(self.torch_device)
+
+        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
+        attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
+
+        # check if `generate` works
+        _ = model.generate(input_ids=input_ids, attention_mask=attention_mask)
+
+        with self.assertRaises(TypeError):
+            # check that `generate` raises a TypeError when `input_ids` is passed positionally
+            _ = model.generate(input_ids, attention_mask=attention_mask)
+
+    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_DECODER_MODELS_TO_TEST))
+    def test_generate(self, test_name, model_id, config_cls, config_kwargs):
+        self._test_generate(model_id, config_cls, config_kwargs)

From 29a2ff5457de4088aa236331dd3f5716ef232477 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Tue, 28 Mar 2023 13:52:24 +0000
Subject: [PATCH 4/7] make style

---
 src/peft/peft_model.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index f73a66ac08..c552713a5f 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -582,7 +582,10 @@ def generate(self, **kwargs):
         else:
             if "input_ids" not in kwargs:
                 raise ValueError("input_ids must be provided for Peft model generation")
-            if kwargs.get("attention_mask", None) is not None:
+            if (
+                kwargs.get("attention_mask", None) is not None
+                and self.peft_config.peft_type == PeftType.PROMPT_TUNING
+            ):
                 # concat prompt attention mask
                 prefix_attention_mask = torch.ones(
                     kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens
@@ -611,6 +614,14 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
         model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
         if isinstance(self.peft_config, PromptLearningConfig):
+            if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
+                prefix_attention_mask = torch.ones(
+                    model_kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens
+                ).to(model_kwargs["input_ids"].device)
+                model_kwargs["attention_mask"] = torch.cat(
+                    (prefix_attention_mask, model_kwargs["attention_mask"]), dim=1
+                )
+
             if model_kwargs["past_key_values"] is None and self.peft_config.peft_type == PeftType.PREFIX_TUNING:
                 past_key_values = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0])
                 model_kwargs["past_key_values"] = past_key_values

From b4c9c0f84eb3ec61cd63a6a381279fbdee9dae8a Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Tue, 28 Mar 2023 14:01:30 +0000
Subject: [PATCH 5/7] fix test

---
 src/peft/peft_model.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index c552713a5f..854a0cc0a5 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -582,10 +582,10 @@ def generate(self, **kwargs):
         else:
             if "input_ids" not in kwargs:
                 raise ValueError("input_ids must be provided for Peft model generation")
-            if (
-                kwargs.get("attention_mask", None) is not None
-                and self.peft_config.peft_type == PeftType.PROMPT_TUNING
-            ):
+            if kwargs.get("attention_mask", None) is not None and self.peft_config.peft_type in [
+                PeftType.PROMPT_TUNING,
+                PeftType.P_TUNING,
+            ]:
                 # concat prompt attention mask
                 prefix_attention_mask = torch.ones(
                     kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens

From a3a102255b804d914993968741098ab0fbfa2fcc Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Tue, 28 Mar 2023 14:10:58 +0000
Subject: [PATCH 6/7] add -n

---
 Makefile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index 61549dbd01..03ae1e08df 100644
--- a/Makefile
+++ b/Makefile
@@ -17,4 +17,4 @@ style:
 	doc-builder style src tests --max_len 119

 test:
-	pytest tests/
\ No newline at end of file
+	pytest -n 3 tests/
\ No newline at end of file

From 71c4b3ff62700920085b0874cada6abac7be3d57 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Wed, 29 Mar 2023 12:06:53 +0000
Subject: [PATCH 7/7] skip llama

---
 tests/test_peft_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_peft_model.py b/tests/test_peft_model.py
index ff5ec17c95..275a3cf197 100644
--- a/tests/test_peft_model.py
+++ b/tests/test_peft_model.py
@@ -32,8 +32,8 @@

 # This has to be in the order: model_id, lora_kwargs, prefix_tuning_kwargs, prompt_encoder_kwargs, prompt_tuning_kwargs
 PEFT_DECODER_MODELS_TO_TEST = [
+    # ("HuggingFaceM4/tiny-random-LlamaForCausalLM", {}, {}, {}, {}), wait until the next `transformers` release
     ("hf-internal-testing/tiny-random-OPTForCausalLM", {}, {}, {}, {}),
("HuggingFaceM4/tiny-random-LlamaForCausalLM", {}, {}, {}, {}), ("hf-internal-testing/tiny-random-GPTNeoXForCausalLM", {}, {}, {}, {}), ("hf-internal-testing/tiny-random-GPT2LMHeadModel", {}, {}, {}, {}), ("hf-internal-testing/tiny-random-BloomForCausalLM", {}, {}, {}, {}),