From 8bbfbb723eae8c6816d80901612c46ebdcabf36b Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 17 Nov 2023 12:53:33 +0100 Subject: [PATCH 1/3] FIX Multitask prompt tuning with other tuning init This is WIP. I attempted to fix #1082. While adding tests for the bug, I discovered that I could not make prompt_tuning_init != RANDOM to work. Maybe I'm using it wrong, but I'm not sure what to change. --- src/peft/peft_model.py | 2 +- .../tuners/multitask_prompt_tuning/model.py | 3 +- tests/test_multitask_prompt_tuning.py | 47 +++++++++++++++++-- 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index e0f0977e28..c45a0a0cfd 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -1065,7 +1065,7 @@ def generate(self, **kwargs): self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation return outputs - def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **kwargs): + def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor] = None, **kwargs): peft_config = self.active_peft_config model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs) if peft_config.is_prompt_learning: diff --git a/src/peft/tuners/multitask_prompt_tuning/model.py b/src/peft/tuners/multitask_prompt_tuning/model.py index 72c7feda0e..eeaa423f38 100644 --- a/src/peft/tuners/multitask_prompt_tuning/model.py +++ b/src/peft/tuners/multitask_prompt_tuning/model.py @@ -67,9 +67,10 @@ def __init__(self, config: MultitaskPromptTuningConfig, word_embeddings): "init method" ) + # TODO: There should be an option for safetensors state_dict: dict = torch.load( config.prompt_tuning_init_state_dict_path, - map_location=word_embeddings.device, + map_location=word_embeddings.weight.device, ) if config.prompt_tuning_init in [ diff --git a/tests/test_multitask_prompt_tuning.py b/tests/test_multitask_prompt_tuning.py index 9aa6b8d7d9..cbdf121534 100644 --- a/tests/test_multitask_prompt_tuning.py +++ b/tests/test_multitask_prompt_tuning.py @@ -19,12 +19,13 @@ from unittest import TestCase import torch +from parameterized import parameterized from torch.testing import assert_close from peft.mapping import get_peft_model from peft.peft_model import PeftModel -from peft.tuners.multitask_prompt_tuning import MultitaskPromptTuningConfig -from peft.utils.other import prepare_model_for_int8_training +from peft.tuners.multitask_prompt_tuning import MultitaskPromptTuningConfig, MultitaskPromptTuningInit +from peft.utils.other import WEIGHTS_NAME, prepare_model_for_int8_training from peft.utils.save_and_load import get_peft_model_state_dict from tests.testing_common import PeftCommonTester @@ -74,7 +75,9 @@ def _create_multitask_prompt_tuning_config(cls) -> MultitaskPromptTuningConfig: task_type="CAUSAL_LM", num_virtual_tokens=50, num_tasks=3, - prompt_tuning_init_text="classify the following into either positive or negative, or entailment, neutral or contradiction:", + prompt_tuning_init_text=( + "classify the following into either positive or negative, or entailment, neutral or contradiction:" + ), ) def test_prepare_for_training(self) -> None: @@ -245,3 +248,41 @@ def test_bf16_inference(self) -> None: mpt = get_peft_model(original, self._create_multitask_prompt_tuning_config()) mpt = mpt.to(self.torch_device) _ = mpt.generate(input_ids=input_ids, task_ids=task_ids) + + @parameterized.expand( + [ + MultitaskPromptTuningInit.TEXT, + MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS, 
+ MultitaskPromptTuningInit.EXACT_SOURCE_TASK, + MultitaskPromptTuningInit.ONLY_SOURCE_SHARED, + ], + ) + def test_generate_text(self, prompt_tuning_init) -> None: + with tempfile.TemporaryDirectory() as tmp_dirname: + model = LlamaForCausalLM(self._create_test_llama_config()) + model = get_peft_model(model, self._create_multitask_prompt_tuning_config()) + model.save_pretrained(tmp_dirname, safe_serialization=False) # bc torch.load is used + + config = MultitaskPromptTuningConfig( + task_type="CAUSAL_LM", + num_virtual_tokens=50, + num_tasks=1, + prompt_tuning_init_text=( + "classify the following into either positive or negative, or entailment, neutral or contradiction:" + ), + prompt_tuning_init=prompt_tuning_init, + prompt_tuning_init_state_dict_path=os.path.join(tmp_dirname, WEIGHTS_NAME), + ) + model = get_peft_model(model, config) + model = model.to(self.torch_device) + + input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device) + attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device) + task_ids = torch.LongTensor([0]).to(self.torch_device) + + # check if `generate` works + _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids) + + with self.assertRaises(TypeError): + # check if `generate` raises an error if no positional arguments are passed + _ = model.generate(input_ids, attention_mask=attention_mask) From 7d88433bcd76ab552e3b03f6b335d63c3f5adeb9 Mon Sep 17 00:00:00 2001 From: Mayank Mishra <32954280+mayank31398@users.noreply.github.com> Date: Fri, 9 Feb 2024 05:24:18 -0500 Subject: [PATCH 2/3] Fix multitask prompt tuning other inits (#8) --- tests/test_multitask_prompt_tuning.py | 28 +++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/tests/test_multitask_prompt_tuning.py b/tests/test_multitask_prompt_tuning.py index b272c53c22..bdd2e1138f 100644 --- a/tests/test_multitask_prompt_tuning.py +++ b/tests/test_multitask_prompt_tuning.py @@ -249,15 +249,34 @@ def test_bf16_inference(self) -> None: mpt = mpt.to(self.torch_device) _ = mpt.generate(input_ids=input_ids, task_ids=task_ids) + def test_generate_text_with_random_init(self) -> None: + model = LlamaForCausalLM(self._create_test_llama_config()) + + config = self._create_multitask_prompt_tuning_config() + config.prompt_tuning_init = MultitaskPromptTuningInit.RANDOM + + model = get_peft_model(model, config) + model = model.to(self.torch_device) + + input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device) + attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device) + task_ids = torch.LongTensor([0]).to(self.torch_device) + + # check if `generate` works + _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids) + + with self.assertRaises(ValueError): + # check if `generate` raises an error if task_ids are not passed + _ = model.generate(input_ids, attention_mask=attention_mask) + @parameterized.expand( [ - MultitaskPromptTuningInit.TEXT, MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS, MultitaskPromptTuningInit.EXACT_SOURCE_TASK, MultitaskPromptTuningInit.ONLY_SOURCE_SHARED, ], ) - def test_generate_text(self, prompt_tuning_init) -> None: + def test_generate_text_with_other_init(self, prompt_tuning_init) -> None: with tempfile.TemporaryDirectory() as tmp_dirname: model = LlamaForCausalLM(self._create_test_llama_config()) model = get_peft_model(model, self._create_multitask_prompt_tuning_config()) @@ -273,6 +292,7 @@ def test_generate_text(self, 
prompt_tuning_init) -> None: prompt_tuning_init=prompt_tuning_init, prompt_tuning_init_state_dict_path=os.path.join(tmp_dirname, WEIGHTS_NAME), ) + model = LlamaForCausalLM(self._create_test_llama_config()) model = get_peft_model(model, config) model = model.to(self.torch_device) @@ -283,6 +303,6 @@ def test_generate_text(self, prompt_tuning_init) -> None: # check if `generate` works _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids) - with self.assertRaises(TypeError): - # check if `generate` raises an error if no positional arguments are passed + with self.assertRaises(ValueError): + # check if `generate` raises an error if task_ids are not passed _ = model.generate(input_ids, attention_mask=attention_mask) From 4e6f3803c97c693a32b22ab64df3995dacb2b196 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 19 Feb 2024 12:37:55 +0100 Subject: [PATCH 3/3] Use assert (pytest style) --- tests/test_multitask_prompt_tuning.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_multitask_prompt_tuning.py b/tests/test_multitask_prompt_tuning.py index bad0f7c1f8..99b7519b37 100644 --- a/tests/test_multitask_prompt_tuning.py +++ b/tests/test_multitask_prompt_tuning.py @@ -17,6 +17,7 @@ import tempfile from unittest import TestCase +import pytest import torch from parameterized import parameterized from torch.testing import assert_close @@ -260,7 +261,7 @@ def test_generate_text_with_random_init(self) -> None: # check if `generate` works _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # check if `generate` raises an error if task_ids are not passed _ = model.generate(input_ids, attention_mask=attention_mask) @@ -298,6 +299,6 @@ def test_generate_text_with_other_init(self, prompt_tuning_init) -> None: # check if `generate` works _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # check if `generate` raises an error if task_ids are not passed _ = model.generate(input_ids, attention_mask=attention_mask)
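
For reference, below is a minimal sketch of the non-RANDOM init flow that these patches fix and that the new tests exercise, written against the same APIs the tests import. The tiny LlamaConfig values and the "source_prompt" directory name are illustrative assumptions, not taken from the patches; the save-then-reinit steps mirror test_generate_text_with_other_init.

import os

import torch
from transformers import LlamaConfig, LlamaForCausalLM

from peft.mapping import get_peft_model
from peft.tuners.multitask_prompt_tuning import MultitaskPromptTuningConfig, MultitaskPromptTuningInit
from peft.utils.other import WEIGHTS_NAME

# Toy base model; the config values here are assumptions chosen only to keep the sketch self-contained.
base_config = LlamaConfig(
    vocab_size=16, hidden_size=8, intermediate_size=8, num_hidden_layers=2, num_attention_heads=4
)
source_base = LlamaForCausalLM(base_config)

# Step 1: create and save a source multitask prompt. safe_serialization=False because the
# non-RANDOM init paths load the source weights with torch.load (see the TODO added in model.py).
source_config = MultitaskPromptTuningConfig(
    task_type="CAUSAL_LM",
    num_virtual_tokens=50,
    num_tasks=3,
    prompt_tuning_init=MultitaskPromptTuningInit.RANDOM,
)
source_model = get_peft_model(source_base, source_config)
source_model.save_pretrained("source_prompt", safe_serialization=False)

# Step 2: initialize a new prompt from the saved source weights. A fresh base model is wrapped
# here, not the already-wrapped source_model -- the same correction patch 2 makes in the test.
target_config = MultitaskPromptTuningConfig(
    task_type="CAUSAL_LM",
    num_virtual_tokens=50,
    num_tasks=1,
    prompt_tuning_init=MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
    prompt_tuning_init_state_dict_path=os.path.join("source_prompt", WEIGHTS_NAME),
)
target_base = LlamaForCausalLM(base_config)
target_model = get_peft_model(target_base, target_config)

# Multitask prompt tuning needs task_ids at generation time; omitting them raises a ValueError,
# which is what the updated tests assert.
input_ids = torch.LongTensor([[1, 1, 1]])
attention_mask = torch.LongTensor([[1, 1, 1]])
task_ids = torch.LongTensor([0])
_ = target_model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids)

Saving with safe_serialization=False is what the "# bc torch.load is used" comment in the test refers to: with the default safetensors serialization, the torch.load call in model.py cannot read the saved file.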