
Commit 33be8ea

FIX: Multitask prompt tuning with other tuning init (huggingface#1144)
Resolves huggingface#1082. Also adds tests for prompt_tuning_init != RANDOM.

Co-authored-by: Mayank Mishra <32954280+mayank31398@users.noreply.github.com>
1 parent 16da7ec commit 33be8ea

File tree: 3 files changed, +68 -5 lines

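For orientation, a minimal sketch (not part of the commit) of the workflow this fix unblocks: initializing a new multitask prompt tuning adapter from a previously saved multitask prompt checkpoint instead of from random weights. The base model and checkpoint path below are placeholders; the peft imports mirror the ones used in the updated test file.

import os

import torch
from transformers import AutoModelForCausalLM

from peft.mapping import get_peft_model
from peft.tuners.multitask_prompt_tuning import MultitaskPromptTuningConfig, MultitaskPromptTuningInit

# Placeholder base model and checkpoint; the checkpoint is assumed to be a multitask
# prompt tuning state dict from a compatible run, saved with safe_serialization=False.
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
init_state_dict = os.path.join("source_run", "adapter_model.bin")

config = MultitaskPromptTuningConfig(
    task_type="CAUSAL_LM",
    num_virtual_tokens=50,
    num_tasks=1,
    prompt_tuning_init=MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
    prompt_tuning_init_state_dict_path=init_state_dict,
)
model = get_peft_model(base_model, config)

# task_ids selects the per-task prompt; omitting it raises a ValueError (see the new tests).
input_ids = torch.LongTensor([[1, 1, 1]])
_ = model.generate(input_ids=input_ids, task_ids=torch.LongTensor([0]))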

src/peft/peft_model.py (+1, -1)

@@ -1155,7 +1155,7 @@ def generate(self, *args, **kwargs):
             self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
             return outputs

-    def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **kwargs):
+    def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor] = None, **kwargs):
         peft_config = self.active_peft_config
         model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)

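The functional behavior is unchanged here: task_ids already defaulted to None, and the signature now spells that out with Optional[torch.Tensor], since an implicit Optional (a None default without Optional in the annotation) is flagged by type checkers. A standalone illustration of the pattern, not peft code:

from typing import Optional

import torch


def prepare_inputs(*args, task_ids: Optional[torch.Tensor] = None, **kwargs) -> dict:
    # With a None default, the annotation must explicitly allow None.
    if task_ids is not None:
        kwargs["task_ids"] = task_ids
    return kwargs


print(prepare_inputs(task_ids=torch.LongTensor([0])))  # {'task_ids': tensor([0])}
print(prepare_inputs())  # {}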
src/peft/tuners/multitask_prompt_tuning/model.py (+2, -1)

@@ -66,9 +66,10 @@ def __init__(self, config: MultitaskPromptTuningConfig, word_embeddings):
                     "init method"
                 )

+            # TODO: There should be an option for safetensors
             state_dict: dict = torch.load(
                 config.prompt_tuning_init_state_dict_path,
-                map_location=word_embeddings.device,
+                map_location=word_embeddings.weight.device,
             )

         if config.prompt_tuning_init in [
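For context on huggingface#1082: word_embeddings is an embedding module, and nn.Module subclasses do not expose a .device attribute, so the old map_location presumably raised an AttributeError as soon as any non-RANDOM init was used; the weight tensor's device is the reliable handle. A standalone illustration, not peft code:

from torch import nn

emb = nn.Embedding(10, 4)  # stand-in for the model's word embeddings

try:
    emb.device  # nn.Module defines no such attribute
except AttributeError as err:
    print(err)  # 'Embedding' object has no attribute 'device'

# The weight tensor does carry a device, which is what the fixed map_location uses:
print(emb.weight.device)  # cpu (or wherever the embedding lives)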

tests/test_multitask_prompt_tuning.py (+65, -3)

@@ -17,13 +17,15 @@
 import tempfile
 from unittest import TestCase

+import pytest
 import torch
+from parameterized import parameterized
 from torch.testing import assert_close

 from peft.mapping import get_peft_model
 from peft.peft_model import PeftModel
-from peft.tuners.multitask_prompt_tuning import MultitaskPromptTuningConfig
-from peft.utils.other import prepare_model_for_int8_training
+from peft.tuners.multitask_prompt_tuning import MultitaskPromptTuningConfig, MultitaskPromptTuningInit
+from peft.utils.other import WEIGHTS_NAME, prepare_model_for_int8_training
 from peft.utils.save_and_load import get_peft_model_state_dict
 from tests.testing_common import PeftCommonTester

@@ -73,7 +75,9 @@ def _create_multitask_prompt_tuning_config(cls) -> MultitaskPromptTuningConfig:
             task_type="CAUSAL_LM",
             num_virtual_tokens=50,
             num_tasks=3,
-            prompt_tuning_init_text="classify the following into either positive or negative, or entailment, neutral or contradiction:",
+            prompt_tuning_init_text=(
+                "classify the following into either positive or negative, or entailment, neutral or contradiction:"
+            ),
         )

     def test_prepare_for_training(self) -> None:
@@ -240,3 +244,61 @@ def test_bf16_inference(self) -> None:
         mpt = get_peft_model(original, self._create_multitask_prompt_tuning_config())
         mpt = mpt.to(self.torch_device)
         _ = mpt.generate(input_ids=input_ids, task_ids=task_ids)
+
+    def test_generate_text_with_random_init(self) -> None:
+        model = LlamaForCausalLM(self._create_test_llama_config())
+
+        config = self._create_multitask_prompt_tuning_config()
+        config.prompt_tuning_init = MultitaskPromptTuningInit.RANDOM
+
+        model = get_peft_model(model, config)
+        model = model.to(self.torch_device)
+
+        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
+        attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
+        task_ids = torch.LongTensor([0]).to(self.torch_device)
+
+        # check if `generate` works
+        _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids)
+
+        with pytest.raises(ValueError):
+            # check if `generate` raises an error if task_ids are not passed
+            _ = model.generate(input_ids, attention_mask=attention_mask)
+
+    @parameterized.expand(
+        [
+            MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
+            MultitaskPromptTuningInit.EXACT_SOURCE_TASK,
+            MultitaskPromptTuningInit.ONLY_SOURCE_SHARED,
+        ],
+    )
+    def test_generate_text_with_other_init(self, prompt_tuning_init) -> None:
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            model = LlamaForCausalLM(self._create_test_llama_config())
+            model = get_peft_model(model, self._create_multitask_prompt_tuning_config())
+            model.save_pretrained(tmp_dirname, safe_serialization=False)  # bc torch.load is used
+
+            config = MultitaskPromptTuningConfig(
+                task_type="CAUSAL_LM",
+                num_virtual_tokens=50,
+                num_tasks=1,
+                prompt_tuning_init_text=(
+                    "classify the following into either positive or negative, or entailment, neutral or contradiction:"
+                ),
+                prompt_tuning_init=prompt_tuning_init,
+                prompt_tuning_init_state_dict_path=os.path.join(tmp_dirname, WEIGHTS_NAME),
+            )
+            model = LlamaForCausalLM(self._create_test_llama_config())
+            model = get_peft_model(model, config)
+            model = model.to(self.torch_device)
+
+            input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
+            attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
+            task_ids = torch.LongTensor([0]).to(self.torch_device)
+
+            # check if `generate` works
+            _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids)
+
+            with pytest.raises(ValueError):
+                # check if `generate` raises an error if task_ids are not passed
+                _ = model.generate(input_ids, attention_mask=attention_mask)
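The new tests cover RANDOM init plus the three source-task init modes. Assuming a development checkout of peft with its test dependencies (pytest, parameterized) installed, they can be run with, for example, pytest tests/test_multitask_prompt_tuning.py -k generate_text.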
