fix: lora dropout applied to all models (pytorch#995)
Optimox authored and maximegmd committed Jul 13, 2024
1 parent f4c1d70 commit 2352df3
Showing 5 changed files with 30 additions and 4 deletions.
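
The builders touched here already accepted a lora_dropout argument, but before this commit several call sites did not forward it to the LoRALinear adapters (or to the LoRA MLP builders), so those adapters silently fell back to the default dropout of 0.0. The hunks below simply thread the value through every adapter construction. As background, the sketch below shows where that dropout sits in a standard LoRA layer's forward pass; it illustrates the general pattern only, not torchtune's actual LoRALinear, and the attribute names (base, lora_a, lora_b) are assumptions made for the example.

    import torch
    import torch.nn as nn

    class LoRALinearSketch(nn.Module):
        """Illustrative LoRA adapter: dropout acts on the low-rank path's input only."""

        def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float, dropout: float = 0.0):
            super().__init__()
            self.base = nn.Linear(in_dim, out_dim, bias=False)  # stands in for the frozen pretrained weight
            self.lora_a = nn.Linear(in_dim, rank, bias=False)
            self.lora_b = nn.Linear(rank, out_dim, bias=False)
            nn.init.zeros_(self.lora_b.weight)                  # conventional LoRA init: adapter starts as a no-op
            self.dropout = nn.Dropout(p=dropout)
            self.scaling = alpha / rank

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # The frozen path is untouched; only the adapter input is randomly masked during training.
            return self.base(x) + self.scaling * self.lora_b(self.lora_a(self.dropout(x)))

With dropout=0.0 the adapter path is deterministic, which is why the missing keyword changed behaviour silently instead of raising an error.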
5 changes: 5 additions & 0 deletions torchtune/models/gemma/_component_builders.py
@@ -209,6 +209,7 @@ def lora_gemma(
             lora_rank=lora_rank,
             lora_alpha=lora_alpha,
             quantize_base=quantize_base,
+            lora_dropout=lora_dropout,
         )
     else:
         mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
@@ -278,6 +279,7 @@ def lora_gemma_self_attention(
             num_heads * head_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "q_proj" in lora_modules
@@ -289,6 +291,7 @@ def lora_gemma_self_attention(
             num_kv_heads * head_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "k_proj" in lora_modules
@@ -300,6 +303,7 @@ def lora_gemma_self_attention(
             num_kv_heads * head_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "v_proj" in lora_modules
@@ -311,6 +315,7 @@ def lora_gemma_self_attention(
             embed_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "output_proj" in lora_modules
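To see the forwarded keyword take effect with torchtune's own LoRALinear (called exactly as in the hunks above), a quick train/eval check works. The import path and the re-initialization step are assumptions: LoRA's B matrix is conventionally zero-initialized, so the parameters are re-randomized here purely to make the adapter path observable, and the layer sizes are illustrative.

    import torch
    from torchtune.modules.peft import LoRALinear  # assumed import path

    torch.manual_seed(0)
    layer = LoRALinear(64, 64, rank=8, alpha=16, dropout=0.5)
    for p in layer.parameters():
        torch.nn.init.normal_(p)  # give the low-rank path non-zero weights so its output is visible

    x = torch.randn(2, 64)
    layer.train()
    assert not torch.allclose(layer(x), layer(x))  # dropout resamples its mask on every call
    layer.eval()
    assert torch.allclose(layer(x), layer(x))      # dropout is a no-op outside training
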
7 changes: 6 additions & 1 deletion torchtune/models/llama2/_component_builders.py
@@ -218,6 +218,7 @@ def lora_llama2(
             lora_rank=lora_rank,
             lora_alpha=lora_alpha,
             quantize_base=quantize_base,
+            lora_dropout=lora_dropout,
         )
     else:
         mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim)
@@ -233,7 +234,7 @@ def lora_llama2(
 
     # TODO: quantize_base is not applied to final output_proj currently.
     output_proj = (
-        LoRALinear(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha)
+        LoRALinear(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout)
         if apply_lora_to_output
         else nn.Linear(embed_dim, vocab_size, bias=False)
     )
@@ -323,6 +324,7 @@ def lora_llama2_self_attention(
             num_heads * head_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "q_proj" in lora_modules
@@ -334,6 +336,7 @@ def lora_llama2_self_attention(
             num_kv_heads * head_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "k_proj" in lora_modules
@@ -345,6 +348,7 @@ def lora_llama2_self_attention(
             num_kv_heads * head_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "v_proj" in lora_modules
@@ -356,6 +360,7 @@ def lora_llama2_self_attention(
             embed_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "output_proj" in lora_modules
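At the model level, the llama2 change can be spot-checked by building a deliberately tiny lora_llama2 and listing the dropout probabilities it contains. Only lora_rank, lora_alpha, lora_dropout, and quantize_base appear in this diff; every other argument below is an illustrative assumption about the builder's signature, so treat this as a sketch rather than a known-good call.

    import torch.nn as nn
    from torchtune.models.llama2._component_builders import lora_llama2

    model = lora_llama2(
        lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"],
        apply_lora_to_mlp=True,
        apply_lora_to_output=True,
        vocab_size=128,      # illustrative toy sizes from here down
        num_layers=2,
        num_heads=4,
        num_kv_heads=4,
        embed_dim=64,
        max_seq_len=256,
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.1,
    )

    probs = {m.p for m in model.modules() if isinstance(m, nn.Dropout)}
    print(probs)  # expected {0.1}; before this commit the adapters stayed at the default 0.0
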
7 changes: 6 additions & 1 deletion torchtune/models/llama3/_component_builders.py
@@ -215,6 +215,7 @@ def lora_llama3(
             lora_rank=lora_rank,
             lora_alpha=lora_alpha,
             quantize_base=quantize_base,
+            lora_dropout=lora_dropout,
         )
     else:
         mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim)
@@ -230,7 +231,7 @@ def lora_llama3(
 
     # TODO: quantize_base is not applied to final output_proj currently.
     output_proj = (
-        LoRALinear(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha)
+        LoRALinear(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout)
         if apply_lora_to_output
         else nn.Linear(embed_dim, vocab_size, bias=False)
     )
@@ -315,6 +316,7 @@ def lora_llama3_self_attention(
             num_heads * head_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "q_proj" in lora_modules
@@ -326,6 +328,7 @@ def lora_llama3_self_attention(
             num_kv_heads * head_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "k_proj" in lora_modules
@@ -337,6 +340,7 @@ def lora_llama3_self_attention(
             num_kv_heads * head_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "v_proj" in lora_modules
@@ -348,6 +352,7 @@ def lora_llama3_self_attention(
             embed_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "output_proj" in lora_modules
8 changes: 7 additions & 1 deletion torchtune/models/mistral/_component_builders.py
@@ -205,6 +205,7 @@ def lora_mistral(
             hidden_dim=intermediate_dim,
             lora_rank=lora_rank,
             lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
             quantize_base=quantize_base,
         )
     else:
@@ -311,6 +312,7 @@ def lora_mistral_self_attention(
             num_heads * head_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "q_proj" in lora_modules
@@ -322,6 +324,7 @@ def lora_mistral_self_attention(
             num_kv_heads * head_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "k_proj" in lora_modules
@@ -333,6 +336,7 @@ def lora_mistral_self_attention(
             num_kv_heads * head_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "v_proj" in lora_modules
@@ -344,6 +348,7 @@ def lora_mistral_self_attention(
             embed_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "output_proj" in lora_modules
@@ -568,6 +573,7 @@ def lora_mistral_classifier(
             hidden_dim=intermediate_dim,
             lora_rank=lora_rank,
             lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
             quantize_base=quantize_base,
         )
     else:
@@ -584,7 +590,7 @@ def lora_mistral_classifier(
 
     # TODO: quantize_base is not applied to final output_proj currently.
    output_proj = (
-        LoRALinear(embed_dim, num_classes, rank=lora_rank, alpha=lora_alpha)
+        LoRALinear(embed_dim, num_classes, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout)
         if apply_lora_to_output
         else nn.Linear(embed_dim, num_classes, bias=False)
     )
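The mistral portion also updates lora_mistral_classifier, so sequence-classification heads receive the same treatment as the decoder builders: the final embed_dim -> num_classes adapter now carries the model-wide dropout instead of a hard-coded 0.0. A minimal sketch of that head in isolation (assumed import path; the sizes stand in for embed_dim and num_classes):

    import torch.nn as nn
    from torchtune.modules.peft import LoRALinear  # assumed import path

    head = LoRALinear(64, 2, rank=8, alpha=16, dropout=0.05)
    print([m for m in head.modules() if isinstance(m, nn.Dropout)])  # expected: [Dropout(p=0.05, ...)]
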
7 changes: 6 additions & 1 deletion torchtune/models/phi3/_component_builders.py
@@ -201,6 +201,7 @@ def lora_phi3(
             lora_rank=lora_rank,
             lora_alpha=lora_alpha,
             quantize_base=quantize_base,
+            lora_dropout=lora_dropout,
         )
     else:
         mlp = phi3_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
@@ -216,7 +217,7 @@ def lora_phi3(
 
     # TODO: quantize_base is not applied to final output_proj currently.
     output_proj = (
-        LoRALinear(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha)
+        LoRALinear(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout)
         if apply_lora_to_output
         else nn.Linear(embed_dim, vocab_size, bias=False)
     )
@@ -303,6 +304,7 @@ def lora_phi3_self_attention(
             num_heads * head_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "q_proj" in lora_modules
@@ -314,6 +316,7 @@ def lora_phi3_self_attention(
             num_kv_heads * head_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "k_proj" in lora_modules
@@ -325,6 +328,7 @@ def lora_phi3_self_attention(
             num_kv_heads * head_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "v_proj" in lora_modules
@@ -336,6 +340,7 @@ def lora_phi3_self_attention(
             embed_dim,
             rank=lora_rank,
             alpha=lora_alpha,
+            dropout=lora_dropout,
             quantize_base=quantize_base,
         )
         if "output_proj" in lora_modules
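One practical note on the change as a whole: dropout layers carry no parameters or buffers, so LoRA adapter checkpoints saved before and after this fix keep exactly the same state-dict keys; only training-time behaviour differs. A self-contained check under the same assumed import path and illustrative sizes:

    from torchtune.modules.peft import LoRALinear  # assumed import path

    keys_without = set(LoRALinear(64, 64, rank=8, alpha=16, dropout=0.0).state_dict())
    keys_with = set(LoRALinear(64, 64, rank=8, alpha=16, dropout=0.1).state_dict())
    assert keys_without == keys_with  # nn.Dropout adds no state, so checkpoints are unaffected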
