
Add default IA3 target modules for Mixtral #1376

Merged
4 commits merged into huggingface:main on Feb 15, 2024

Conversation

@arnavgarg1 commented on Jan 19, 2024

Here's the Mixtral model architecture with my proposed IA3 target module mapping:

PeftModelForCausalLM(                                                                                                                                       
  (base_model): IA3Model(                                                                                                                                   
    (model): MixtralForCausalLM(                                                                                                                            
      (model): MixtralModel(                                                                                                                                
        (embed_tokens): Embedding(32000, 4096)                                                                                                              
        (layers): ModuleList(                                                                                                                               
          (0-13): 14 x MixtralDecoderLayer(                                                                                                                 
            (self_attn): MixtralAttention(                                                                                                                  
              (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)                                                                         
              (k_proj): ia3.Linear4bit(                                                                                                                     
                (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)                                                                   
                (ia3_l): ParameterDict(  (default): Parameter containing: [torch.cuda.HalfTensor of size 1024x1 (cuda:0)])                                  
              )                                                                                                                                             
              (v_proj): ia3.Linear4bit(                                                                                                                     
                (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)                                                                   
                (ia3_l): ParameterDict(  (default): Parameter containing: [torch.cuda.HalfTensor of size 1024x1 (cuda:0)])                                  
              )                                                                                                                                             
              (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)                                                                         
              (rotary_emb): MixtralRotaryEmbedding()                                                                                                        
            )                                                                                                                                               
            (block_sparse_moe): MixtralSparseMoeBlock(                                                                                                      
              (gate): Linear4bit(in_features=4096, out_features=8, bias=False)                                                                              
              (experts): ModuleList(                                                                                                                        
                (0-7): 8 x MixtralBLockSparseTop2MLP(                                                                                                       
                  (w1): ia3.Linear4bit(                                                                                                                     
                    (base_layer): Linear4bit(in_features=4096, out_features=14336, bias=False)                                                              
                    (ia3_l): ParameterDict(  (default): Parameter containing: [torch.cuda.HalfTensor of size 1x4096 (cuda:0)])                              
                  )                                                                                                                                         
                  (w2): ia3.Linear4bit(                                                                                                                     
                    (base_layer): Linear4bit(in_features=14336, out_features=4096, bias=False)                                                              
                    (ia3_l): ParameterDict(  (default): Parameter containing: [torch.cuda.HalfTensor of size 1x14336 (cuda:0)])                             
                  )                                                                                                                                         
                  (w3): ia3.Linear4bit(                                                                                                                     
                    (base_layer): Linear4bit(in_features=4096, out_features=14336, bias=False)                                                              
                    (ia3_l): ParameterDict(  (default): Parameter containing: [torch.cuda.HalfTensor of size 1x4096 (cuda:0)])                              
                  )                                                                                                                                         
                  (act_fn): SiLU()                                                                                                                          
                )                                                                                                                                           
              )                                                                                                                                             
            )                                                                                                                                               
            (input_layernorm): MixtralRMSNorm()                                                                                                             
            (post_attention_layernorm): MixtralRMSNorm()
          )
          (14-31): 18 x MixtralDecoderLayer(
            (self_attn): MixtralAttention(
              (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
              (k_proj): ia3.Linear4bit( 
                (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)
                (ia3_l): ParameterDict(  (default): Parameter containing: [torch.cuda.HalfTensor of size 1024x1 (cuda:1)])
              )
              (v_proj): ia3.Linear4bit( 
                (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)
                (ia3_l): ParameterDict(  (default): Parameter containing: [torch.cuda.HalfTensor of size 1024x1 (cuda:1)])
              )
              (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
              (rotary_emb): MixtralRotaryEmbedding()
            )
            (block_sparse_moe): MixtralSparseMoeBlock(
              (gate): Linear4bit(in_features=4096, out_features=8, bias=False) 
              (experts): ModuleList(
                (0-7): 8 x MixtralBLockSparseTop2MLP(
                  (w1): ia3.Linear4bit( 
                    (base_layer): Linear4bit(in_features=4096, out_features=14336, bias=False)
                    (ia3_l): ParameterDict(  (default): Parameter containing: [torch.cuda.HalfTensor of size 1x4096 (cuda:1)])
                  )
                  (w2): ia3.Linear4bit( 
                    (base_layer): Linear4bit(in_features=14336, out_features=4096, bias=False)
                    (ia3_l): ParameterDict(  (default): Parameter containing: [torch.cuda.HalfTensor of size 1x14336 (cuda:1)])
                  )
                  (w3): ia3.Linear4bit( 
                    (base_layer): Linear4bit(in_features=4096, out_features=14336, bias=False)
                    (ia3_l): ParameterDict(  (default): Parameter containing: [torch.cuda.HalfTensor of size 1x4096 (cuda:1)])
                  )
                  (act_fn): SiLU()
                )
              )
            )
            (input_layernorm): MixtralRMSNorm()
            (post_attention_layernorm): MixtralRMSNorm()
          )
        )
        (norm): MixtralRMSNorm()
      )
      (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
    )
  )
)

Here is the resulting number of trainable parameters:

trainable params: 11,665,408 || all params: 46,714,458,112 || trainable%: 0.024971729249286513
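
For reference, a minimal sketch of how this setup could be reproduced with PEFT; the checkpoint name and quantization settings below are illustrative assumptions, while the module lists mirror the ones proposed in this PR:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import IA3Config, get_peft_model, prepare_model_for_kbit_training

# Load Mixtral in 4-bit, matching the Linear4bit base layers in the printout above.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",  # assumed checkpoint
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    ),
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)

# Target modules as proposed in this PR. feedforward_modules must be a subset of
# target_modules; for those layers the (IA)^3 vector scales the layer input,
# which is why w1/w2/w3 show ia3_l of shape (1, in_features) above.
peft_config = IA3Config(
    task_type="CAUSAL_LM",
    target_modules=["k_proj", "v_proj", "w1", "w2", "w3"],
    feedforward_modules=["w1", "w2", "w3"],
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # reported above: trainable% ≈ 0.025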

@younesbelkada left a comment

Hi @arnavgarg1, thanks for the contribution!
I didn't notice your PR until now. I made #1380 a few days ago, which adds Mixtral to the LoRA mapping. Would you be happy to convert this PR into one that adds Mixtral to the IA3 mapping instead?

@arnavgarg1

@younesbelkada Yes!

@pacman100 left a comment

Hello @arnavgarg1, as Younes mentioned, please update the PR to add target modules for Mixtral when using IA3.

@arnavgarg1 changed the title from "Add default LoRA target modules for Mixtral" to "Add default IA3 target modules for Mixtral" on Jan 29, 2024
@arnavgarg1

@pacman100 @younesbelkada Just updated with IA3 instead! I'm also going to add a separate PR for Phi with IA3 right now.

@younesbelkada left a comment

Thanks a lot!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@arnavgarg1

No problem!

@@ -93,6 +93,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"gpt_bigcode": ["c_attn", "mlp.c_proj"],
"llama": ["k_proj", "v_proj", "down_proj"],
"mistral": ["k_proj", "v_proj", "down_proj"],
"mixtral": ["k_proj", "v_proj", "w1", "w2", "w3"],

@arnavgarg1 (Contributor Author)

@pacman100 Would this be k_proj, v_proj and w2?

@pacman100 (Contributor)

Yes, the ffn layer should just be w2 as per the IA3 paper.
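
For context, a schematic sketch of where the (IA)^3 feedforward vector acts inside a Mixtral expert, based on the module printout above; l_ff and the helper function are illustrative names, not PEFT's actual implementation:

import torch
import torch.nn.functional as F

def expert_forward_with_ia3(x, w1, w2, w3, l_ff):
    # Mixtral expert: w2(silu(w1(x)) * w3(x)). The (IA)^3 feedforward vector l_ff
    # rescales the intermediate activation, i.e. the input to w2, which is why
    # only w2 needs to be listed as a feedforward target module.
    h = F.silu(w1(x)) * w3(x)
    return w2(l_ff * h)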

@arnavgarg1

Wanted to check what's left here as next steps, @pacman100 @younesbelkada

@younesbelkada left a comment

LGTM! wdyt @pacman100?

@pacman100 left a comment

Hello @arnavgarg1, thank you for adding IA3 target modules for Mixtral! Please see the comment below; once it's addressed, we can merge this.

@@ -115,6 +116,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"gpt_bigcode": ["mlp.c_proj"],
"llama": ["down_proj"],
"mistral": ["down_proj"],
"mixtral": ["w1", "w2", "w3"],

As discussed above, it should only have w2
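
Putting the review feedback together, a hypothetical sketch of how both Mixtral entries would read once the comments are addressed; the constant names are assumed to be PEFT's IA3 mapping dictionaries, and all other model entries are omitted:

# Hypothetical excerpt: Mixtral's IA3 defaults after the review, scaling k_proj,
# v_proj, and only the w2 feedforward projection.
TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING = {
    "mistral": ["k_proj", "v_proj", "down_proj"],
    "mixtral": ["k_proj", "v_proj", "w2"],
}
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING = {
    "mistral": ["down_proj"],
    "mixtral": ["w2"],  # must be a subset of the target modules
}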

@arnavgarg1 requested a review from pacman100 on February 7, 2024
@arnavgarg1

Thanks @pacman100! Just updated.

@arnavgarg1

Is it good to merge?

@pacman100

Thank you @arnavgarg1! ✨

@arnavgarg1

Thanks!

@younesbelkada merged commit 83de1af into huggingface:main on Feb 15, 2024
14 checks passed
BenjaminBossan pushed a commit to BenjaminBossan/peft that referenced this pull request on Mar 14, 2024
* Add default LoRA target modules for Mixtral

* Add IA3 modules for Mixtral

* Address comments