Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for Mistral Model in Llama-Adapter Method #1433

Merged
merged 9 commits into from
Mar 12, 2024
7 changes: 7 additions & 0 deletions src/peft/tuners/adaption_prompt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ def is_adaption_prompt(self) -> bool:
v_proj_layer="v_proj",
o_proj_layer="o_proj",
),
"mistral": ModelTypeConfig( # same as llama,
compute_query_states=llama_compute_query_states,
target_modules="self_attn",
k_proj_layer="k_proj",
v_proj_layer="v_proj",
o_proj_layer="o_proj",
),
}


Expand Down
18 changes: 13 additions & 5 deletions src/peft/tuners/adaption_prompt/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,31 +74,38 @@ def forward(self, **kwargs):
k_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].k_proj_layer
v_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].v_proj_layer
o_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].o_proj_layer
factor = (
BenjaminBossan marked this conversation as resolved.
Show resolved Hide resolved
self.model.k_proj.in_features // self.model.k_proj.out_features
) # Mistral has different input and output dimension for k_proj and v_proj layers

if k_proj_layer == v_proj_layer:
_, key, value = getattr(self.model, k_proj_layer)(self.adaption_prompt).split(embed_dim, dim=2)
else:
key = getattr(self.model, k_proj_layer)(self.adaption_prompt)
value = getattr(self.model, v_proj_layer)(self.adaption_prompt)
# (bsz, num_heads, adapter_len, head_dim)

# (bsz, num_key_value_heads, adapter_len, head_dim)
adapter_k = (
key.view(1, self.adapter_len, self.model.num_heads, self.model.head_dim)
key.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim)
.repeat(bsz, 1, 1, 1)
.transpose(1, 2)
)
# (bsz, num_heads, adapter_len, head_dim)
adapter_v = (
value.view(1, self.adapter_len, self.model.num_heads, self.model.head_dim)
value.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim)
.repeat(bsz, 1, 1, 1)
.transpose(1, 2)
)

# Below is taken from https://github.com/huggingface/transformers/blob/e547458c43dfdbbb8f6a7757237e234c44e20a8f/src/transformers/models/mistral/modeling_mistral.py#L181
# (bsz, num_heads, adapter_len, head_dim)
adapter_k = torch.repeat_interleave(adapter_k, repeats=factor, dim=1)
adapter_v = torch.repeat_interleave(adapter_v, repeats=factor, dim=1)
# Recompute query states.
compute_query_states = TRANSFORMERS_MODEL_CONFIG[self.model_type].compute_query_states
# (bsz, num_heads, q_len, head_dim)
query_states = compute_query_states(model=self.model, **kwargs)

previous_dtype = query_states.dtype

# (bsz, num_heads, q_len, adapter_len)
scores = torch.matmul(query_states, adapter_k.transpose(2, 3).to(previous_dtype)) / math.sqrt(
self.model.head_dim
Expand All @@ -108,6 +115,7 @@ def forward(self, **kwargs):
scores = self.adaption_gate * F.softmax(scores, dim=-1, dtype=torch.float32).to(previous_dtype)
# (bsz, q_len, num_heads * head_dim)
adapter_output = torch.matmul(scores, adapter_v).transpose(1, 2).reshape(bsz, q_len, -1)

# (bsz, q_len, hidden_size)
if o_proj_layer is not None:
adapter_output = getattr(self.model, o_proj_layer)(adapter_output)
Expand Down
7 changes: 6 additions & 1 deletion src/peft/tuners/adaption_prompt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,12 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
past_key_value = kwargs.get("past_key_value")
bsz, q_len, _ = hidden_states.size()
query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
value_states = model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)

factor = model.k_proj.in_features // model.k_proj.out_features
value_states = (
model.v_proj(hidden_states).view(bsz, q_len, (model.num_heads // factor), model.head_dim).transpose(1, 2)
)

seq_len = q_len

if past_key_value is not None:
Expand Down
Loading
Loading