Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for Mistral Model in Llama-Adapter Method #1433

Merged
merged 9 commits into from
Mar 12, 2024
7 changes: 7 additions & 0 deletions src/peft/tuners/adaption_prompt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ def is_adaption_prompt(self) -> bool:
v_proj_layer="v_proj",
o_proj_layer="o_proj",
),
"mistral": ModelTypeConfig( # same for Mistral,
PrakharSaxena24 marked this conversation as resolved.
Show resolved Hide resolved
compute_query_states=llama_compute_query_states,
target_modules="self_attn",
k_proj_layer="k_proj",
v_proj_layer="v_proj",
o_proj_layer="o_proj",
),
}


Expand Down
21 changes: 19 additions & 2 deletions src/peft/tuners/adaption_prompt/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,25 @@ def forward(self, **kwargs):
k_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].k_proj_layer
v_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].v_proj_layer
o_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].o_proj_layer
factor = (
BenjaminBossan marked this conversation as resolved.
Show resolved Hide resolved
self.model.k_proj.in_features // self.model.k_proj.out_features
) # Mistral has different input and output dimension for k_proj and v_proj layers

if k_proj_layer == v_proj_layer:
_, key, value = getattr(self.model, k_proj_layer)(self.adaption_prompt).split(embed_dim, dim=2)
else:
key = getattr(self.model, k_proj_layer)(self.adaption_prompt)
value = getattr(self.model, v_proj_layer)(self.adaption_prompt)
# (bsz, num_heads, adapter_len, head_dim)

adapter_k = (
key.view(1, self.adapter_len, self.model.num_heads, self.model.head_dim)
key.view(1, self.adapter_len, self.model.num_heads, (self.model.head_dim // factor))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The head dim shouldn't change but the number of heads should be reduced in GQA.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see! Thanks a lot, this seems correct.
Will edit this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I think I will need to do the same in utils.py

PrakharSaxena24 marked this conversation as resolved.
Show resolved Hide resolved
.repeat(bsz, 1, 1, 1)
.transpose(1, 2)
)
# (bsz, num_heads, adapter_len, head_dim)
adapter_v = (
value.view(1, self.adapter_len, self.model.num_heads, self.model.head_dim)
value.view(1, self.adapter_len, self.model.num_heads, (self.model.head_dim // factor))
PrakharSaxena24 marked this conversation as resolved.
Show resolved Hide resolved
.repeat(bsz, 1, 1, 1)
.transpose(1, 2)
)
Expand All @@ -100,6 +104,15 @@ def forward(self, **kwargs):
query_states = compute_query_states(model=self.model, **kwargs)

previous_dtype = query_states.dtype

# Reshape and average the extra tensors
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to reshape and avg query states as the above key shape is (bsz, adapter_seq_len, num_kv_heads, head_dim), the value shape is (bsz, adapter_seq_len, num_kv_heads, head_dim) and query shape is (bsz, adapter_seq_len, num_heads, head_dim). Now, you would need to repeat the num_kv_heads to match num_heads as done in https://github.com/huggingface/transformers/blob/1c31b7aa3bb4e7ef24c77596d2a76f45a770159f/src/transformers/models/mistral/modeling_mistral.py#L193. After that the attn computation is same as normal MHA case.

Copy link
Contributor Author

@PrakharSaxena24 PrakharSaxena24 Feb 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot, so rather than repeating the adapter output, I should repeat adapter_k and adapter_v.
adapter_k = torch.repeat_interleave( adapter_k, repeats=factor, dim=1 )
adapter_v = torch.repeat_interleave( adapter_v, repeats=factor, dim=1 )
as the key, value shape is (bsz, num_kv_heads, adapter_seq_len, head_dim), (dim 1 for num_kv_heads)
Does this makes sense?

query_states_reshaped = query_states.reshape(
bsz, self.model.num_heads, -1, (self.model.head_dim // factor), factor
)

# Take the mean along the last dimension to get [bsz, 32, X, 32]
PrakharSaxena24 marked this conversation as resolved.
Show resolved Hide resolved
query_states = query_states_reshaped.mean(dim=-1)

# (bsz, num_heads, q_len, adapter_len)
scores = torch.matmul(query_states, adapter_k.transpose(2, 3).to(previous_dtype)) / math.sqrt(
self.model.head_dim
Expand All @@ -109,6 +122,10 @@ def forward(self, **kwargs):
scores = self.adaption_gate * F.softmax(scores, dim=-1, dtype=torch.float32).to(previous_dtype)
# (bsz, q_len, num_heads * head_dim)
adapter_output = torch.matmul(scores, adapter_v).transpose(1, 2).reshape(bsz, q_len, -1)

adapter_output = torch.repeat_interleave(
adapter_output, repeats=factor, dim=2
) # https://github.com/huggingface/transformers/blob/e547458c43dfdbbb8f6a7757237e234c44e20a8f/src/transformers/models/mistral/modeling_mistral.py#L181
PrakharSaxena24 marked this conversation as resolved.
Show resolved Hide resolved
# (bsz, q_len, hidden_size)
if o_proj_layer is not None:
adapter_output = getattr(self.model, o_proj_layer)(adapter_output)
Expand Down
6 changes: 5 additions & 1 deletion src/peft/tuners/adaption_prompt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
past_key_value = kwargs.get("past_key_value")
bsz, q_len, _ = hidden_states.size()
query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
value_states = model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)

factor = model.k_proj.in_features // model.k_proj.out_features
value_states = (
model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, (model.head_dim // factor)).transpose(1, 2)
)

seq_len = q_len
if past_key_value is not None:
Expand Down