FIX: fix for dimension mismatch when using adaption prompt in Llama2 #1451

Closed

Changes from all commits
14 changes: 10 additions & 4 deletions src/peft/tuners/adaption_prompt/layer.py
@@ -17,6 +17,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import repeat_kv

from .config import TRANSFORMERS_MODEL_CONFIG

@@ -80,15 +81,15 @@ def forward(self, **kwargs):
else:
key = getattr(self.model, k_proj_layer)(self.adaption_prompt)
value = getattr(self.model, v_proj_layer)(self.adaption_prompt)
# (bsz, num_heads, adapter_len, head_dim)
# (bsz, num_kv_heads, adapter_len, head_dim)
adapter_k = (
key.view(1, self.adapter_len, self.model.num_heads, self.model.head_dim)
key.view(1, self.adapter_len, self.model.num_key_value_heads, self.model.head_dim)
.repeat(bsz, 1, 1, 1)
.transpose(1, 2)
)
# (bsz, num_heads, adapter_len, head_dim)
# (bsz, num_kv_heads, adapter_len, head_dim)
adapter_v = (
value.view(1, self.adapter_len, self.model.num_heads, self.model.head_dim)
value.view(1, self.adapter_len, self.model.num_key_value_heads, self.model.head_dim)
.repeat(bsz, 1, 1, 1)
.transpose(1, 2)
)
@@ -99,6 +100,11 @@
query_states = compute_query_states(model=self.model, **kwargs)

previous_dtype = query_states.dtype

# repeat kv for GQA
adapter_k = repeat_kv(adapter_k, self.model.num_key_value_groups)
adapter_v = repeat_kv(adapter_v, self.model.num_key_value_groups)

# (bsz, num_heads, q_len, adapter_len)
scores = torch.matmul(query_states, adapter_k.transpose(2, 3).to(previous_dtype)) / math.sqrt(
self.model.head_dim
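For context on the layer.py change above: under grouped-query attention (GQA), `k_proj` and `v_proj` only produce `num_key_value_heads` heads, so the adapter keys and values have to be viewed with that head count and then expanded back to `num_heads` before the attention matmul against the query states. Below is a minimal sketch of that expansion, using a stand-alone re-implementation equivalent to the `repeat_kv` helper imported from `transformers.models.llama.modeling_llama`; the tensor sizes are illustrative assumptions, not values from the PR.

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (bsz, num_kv_heads, seq_len, head_dim) to (bsz, num_kv_heads * n_rep, seq_len, head_dim),
    # i.e. each KV head is shared by n_rep query heads (equivalent to repeat_interleave on dim 1).
    bsz, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)


# Illustrative shapes: 4 query heads sharing 2 KV heads -> 2 KV groups.
bsz, adapter_len, num_heads, num_kv_heads, head_dim = 2, 16, 4, 2, 8
adapter_k = torch.randn(bsz, num_kv_heads, adapter_len, head_dim)
adapter_k = repeat_kv(adapter_k, num_heads // num_kv_heads)
assert adapter_k.shape == (bsz, num_heads, adapter_len, head_dim)
```

With `num_key_value_heads == num_heads` (plain multi-head attention) the group size is 1 and the call is a no-op, so the non-GQA path is unaffected.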
4 changes: 3 additions & 1 deletion src/peft/tuners/adaption_prompt/utils.py
@@ -68,7 +68,9 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
past_key_value = kwargs.get("past_key_value")
bsz, q_len, _ = hidden_states.size()
query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
value_states = model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
value_states = (
model.v_proj(hidden_states).view(bsz, q_len, model.num_key_value_heads, model.head_dim).transpose(1, 2)
)

seq_len = q_len
if past_key_value is not None:
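The utils.py change follows the same reasoning: `v_proj` maps `hidden_size` to `num_key_value_heads * head_dim`, so viewing its output with `num_heads` key/value heads raises a shape error as soon as `num_key_value_heads < num_attention_heads`. Here is a small shape check illustrating the mismatch; the dimensions are illustrative assumptions, not taken from the PR.

```python
import torch
import torch.nn as nn

hidden_size, num_heads, num_kv_heads, head_dim = 8, 4, 2, 2
bsz, q_len = 2, 3

v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
values = v_proj(torch.randn(bsz, q_len, hidden_size))  # (2, 3, 4)

# Old view: assumes num_heads KV heads, which needs 2 * 3 * 4 * 2 = 48 elements but only 24 exist.
try:
    values.view(bsz, q_len, num_heads, head_dim)
except RuntimeError as err:
    print("old view fails:", err)

# Fixed view: uses the actual number of KV heads, then moves heads before the sequence dimension.
value_states = values.view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
print(value_states.shape)  # torch.Size([2, 2, 3, 2]) == (bsz, num_kv_heads, q_len, head_dim)
```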
49 changes: 49 additions & 0 deletions tests/test_adaption_prompt.py
@@ -25,6 +25,7 @@
from peft.peft_model import PeftModel
from peft.tuners.adaption_prompt import AdaptionPromptConfig
from peft.utils.other import prepare_model_for_int8_training
from peft.utils.peft_types import TaskType
from peft.utils.save_and_load import get_peft_model_state_dict
from tests.testing_common import PeftCommonTester

@@ -439,3 +440,51 @@ def test_disable_adapter(self):
with model.disable_adapter():
output_peft_disabled = model(dummy_input).logits
self.assertTrue(torch.allclose(output_before, output_peft_disabled))

def test_llama_gqa(self) -> None:
"""Test that AdaptionPrompt works when Llama2 using group query attention(GQA) or not."""
# test for llama with gqa
model_config = LlamaConfig(
vocab_size=16,
hidden_size=8,
intermediate_size=8,
num_hidden_layers=8,
num_attention_heads=4,
num_key_value_heads=2,
)
model = LlamaForCausalLM(config=model_config)
adaption_config = AdaptionPromptConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
target_modules="self_attn",
adapter_len=16,
adapter_layers=4,
)
peft_model = get_peft_model(model, adaption_config)
input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]])
_ = peft_model(input_ids=input_ids)
del model
del peft_model

# test for llama without gqa
model_config = LlamaConfig(
vocab_size=16,
hidden_size=8,
intermediate_size=8,
num_hidden_layers=8,
num_attention_heads=4,
num_key_value_heads=None,
)
model = LlamaForCausalLM(config=model_config)
adaption_config = AdaptionPromptConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
target_modules="self_attn",
adapter_len=16,
adapter_layers=4,
)
peft_model = get_peft_model(model, adaption_config)
input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]])
_ = peft_model(input_ids=input_ids)
del model
del peft_model
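One note on the second test case: when `num_key_value_heads=None` is passed, `LlamaConfig` falls back to `num_attention_heads` for the KV head count (to the best of my reading of the transformers config), so that case exercises the plain multi-head attention path where `repeat_kv` has nothing to expand. A quick sketch of the relationship the fix relies on:

```python
from transformers import LlamaConfig

gqa_config = LlamaConfig(num_attention_heads=4, num_key_value_heads=2)
mha_config = LlamaConfig(num_attention_heads=4, num_key_value_heads=None)

# num_key_value_heads defaults to num_attention_heads when left as None.
print(gqa_config.num_key_value_heads, mha_config.num_key_value_heads)  # 2 4

# Number of query heads sharing each KV head; repeat_kv is a no-op when this is 1.
for cfg in (gqa_config, mha_config):
    print(cfg.num_attention_heads // cfg.num_key_value_heads)  # 2, then 1
```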