From 4690b9bb4cba3686cb039388196ecdc6b66a5f86 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 4 Sep 2024 12:28:03 -0700 Subject: [PATCH 01/12] use identity for dropout if 0 --- torchtune/modules/peft/dora.py | 2 +- torchtune/modules/peft/lora.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtune/modules/peft/dora.py b/torchtune/modules/peft/dora.py index d8ef8016b1..e0a8fe9788 100644 --- a/torchtune/modules/peft/dora.py +++ b/torchtune/modules/peft/dora.py @@ -68,7 +68,7 @@ def __init__( # and disabling it to treat the base model as the reference model self.disabled = False - self.dropout = nn.Dropout(p=dropout) + self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity() self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) self.magnitude = nn.Parameter(torch.empty(out_dim)) diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py index 7af24df1f0..57d72af672 100644 --- a/torchtune/modules/peft/lora.py +++ b/torchtune/modules/peft/lora.py @@ -65,7 +65,7 @@ def __init__( self.register_parameter( "bias", nn.Parameter(bias) if bias is not None else None ) - self.dropout = nn.Dropout(p=dropout) + self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity() self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) self.merged = False From 17e6d790058064682fcb2c5031f2e6f6b03ab1a7 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 4 Sep 2024 12:28:14 -0700 Subject: [PATCH 02/12] update model builders --- torchtune/models/code_llama2/_model_builders.py | 9 +++++---- torchtune/models/gemma/_model_builders.py | 8 ++++++-- torchtune/models/llama2/_model_builders.py | 13 +++++++------ torchtune/models/llama3/_model_builders.py | 7 +++++-- torchtune/models/llama3_1/_model_builders.py | 17 +++++++++++------ torchtune/models/mistral/_model_builders.py | 8 ++++++-- torchtune/models/phi3/_model_builders.py | 4 +++- torchtune/models/qwen2/_model_builders.py | 9 ++++++--- 8 files changed, 49 insertions(+), 26 deletions(-) diff --git a/torchtune/models/code_llama2/_model_builders.py b/torchtune/models/code_llama2/_model_builders.py index cb7917a3d5..cbec735e6e 100644 --- a/torchtune/models/code_llama2/_model_builders.py +++ b/torchtune/models/code_llama2/_model_builders.py @@ -41,7 +41,7 @@ def lora_code_llama2_7b( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, - lora_dropout: float = 0.05, + lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, ) -> TransformerDecoder: @@ -62,7 +62,7 @@ def lora_code_llama2_7b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation - lora_dropout (float): dropout probability for LoRA linear layers. Default: 0.05 + lora_dropout (float): dropout probability for LoRA linear layers. 
Default: 0.0 quantize_base (bool): Whether to quantize base model weights Returns: @@ -125,7 +125,7 @@ def lora_code_llama2_13b( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, - lora_dropout: float = 0.05, + lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, ) -> TransformerDecoder: @@ -212,7 +212,7 @@ def lora_code_llama2_70b( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, - lora_dropout: float = 0.05, + lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, ) -> TransformerDecoder: @@ -233,6 +233,7 @@ def lora_code_llama2_70b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 use_dora (bool): Decompose the LoRA weight into magnitude and direction, as introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). quantize_base (bool): Whether to quantize base model weights diff --git a/torchtune/models/gemma/_model_builders.py b/torchtune/models/gemma/_model_builders.py index cbbf11f09e..611381b974 100644 --- a/torchtune/models/gemma/_model_builders.py +++ b/torchtune/models/gemma/_model_builders.py @@ -68,6 +68,7 @@ def lora_gemma_2b( apply_lora_to_mlp: bool = False, lora_rank: int = 8, lora_alpha: float = 16, + lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, ) -> GemmaTransformerDecoder: @@ -86,6 +87,7 @@ def lora_gemma_2b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0 use_dora (bool): Decompose the LoRA weight into magnitude and direction, as introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). quantize_base (bool): Whether to quantize base model weights @@ -108,7 +110,7 @@ def lora_gemma_2b( norm_eps=1e-6, lora_rank=lora_rank, lora_alpha=lora_alpha, - lora_dropout=0.05, + lora_dropout=lora_dropout, use_dora=use_dora, quantize_base=quantize_base, ) @@ -150,6 +152,7 @@ def lora_gemma_7b( apply_lora_to_mlp: bool = False, lora_rank: int = 8, lora_alpha: float = 16, + lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, ) -> GemmaTransformerDecoder: @@ -168,6 +171,7 @@ def lora_gemma_7b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0 use_dora (bool): Decompose the LoRA weight into magnitude and direction, as introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). 
quantize_base (bool): Whether to quantize base model weights @@ -190,7 +194,7 @@ def lora_gemma_7b( norm_eps=1e-6, lora_rank=lora_rank, lora_alpha=lora_alpha, - lora_dropout=0.05, + lora_dropout=lora_dropout, use_dora=use_dora, quantize_base=quantize_base, ) diff --git a/torchtune/models/llama2/_model_builders.py b/torchtune/models/llama2/_model_builders.py index c9e19d1109..356411e498 100644 --- a/torchtune/models/llama2/_model_builders.py +++ b/torchtune/models/llama2/_model_builders.py @@ -67,7 +67,7 @@ def lora_llama2_7b( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, - lora_dropout: float = 0.05, + lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, ) -> TransformerDecoder: @@ -92,8 +92,7 @@ def lora_llama2_7b( use_dora (bool): Decompose the LoRA weight into magnitude and direction, as introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). quantize_base (bool): Whether to quantize base model weights - lora_dropout (float): dropout probability for LoRA linear layers. Default: 0.05 - + Returns: TransformerDecoder: Instantiation of Llama2 7B model with LoRA applied """ @@ -153,7 +152,7 @@ def lora_llama2_13b( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, - lora_dropout: float = 0.05, + lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, ) -> TransformerDecoder: @@ -238,7 +237,7 @@ def lora_llama2_70b( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, - lora_dropout: float = 0.05, + lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, ) -> TransformerDecoder: @@ -259,6 +258,7 @@ def lora_llama2_70b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 use_dora (bool): Decompose the LoRA weight into magnitude and direction, as introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). quantize_base (bool): Whether to quantize base model weights @@ -323,7 +323,7 @@ def lora_llama2_reward_7b( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, - lora_dropout: float = 0.05, + lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, ) -> TransformerDecoder: @@ -344,6 +344,7 @@ def lora_llama2_reward_7b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 quantize_base (bool): Whether to quantize base model weights Returns: diff --git a/torchtune/models/llama3/_model_builders.py b/torchtune/models/llama3/_model_builders.py index 1be7898bd3..ddfe88d6e7 100644 --- a/torchtune/models/llama3/_model_builders.py +++ b/torchtune/models/llama3/_model_builders.py @@ -115,6 +115,7 @@ def lora_llama3_8b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0 quantize_base (bool): Whether to quantize base model weights use_dora (bool): Decompose the LoRA weight into magnitude and direction, as introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). 
@@ -138,7 +139,7 @@ def lora_llama3_8b( rope_base=500000.0, lora_rank=lora_rank, lora_alpha=lora_alpha, - lora_dropout=0.05, + lora_dropout=lora_dropout, quantize_base=quantize_base, use_dora=use_dora, ) @@ -150,6 +151,7 @@ def lora_llama3_70b( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, + lora_dropout: float = 0.0, quantize_base: bool = False, use_dora: bool = False, ) -> TransformerDecoder: @@ -170,6 +172,7 @@ def lora_llama3_70b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0 quantize_base (bool): Whether to quantize base model weights use_dora (bool): Decompose the LoRA weight into magnitude and direction, as introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). @@ -193,7 +196,7 @@ def lora_llama3_70b( rope_base=500000.0, lora_rank=lora_rank, lora_alpha=lora_alpha, - lora_dropout=0.05, + lora_dropout=lora_dropout, quantize_base=quantize_base, use_dora=use_dora, ) diff --git a/torchtune/models/llama3_1/_model_builders.py b/torchtune/models/llama3_1/_model_builders.py index 3c02a1fad0..18d8ee109d 100644 --- a/torchtune/models/llama3_1/_model_builders.py +++ b/torchtune/models/llama3_1/_model_builders.py @@ -3,16 +3,13 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Optional +from typing import List from functools import partial from torchtune.models.llama3_1._component_builders import llama3_1, lora_llama3_1 from torchtune.modules import TransformerDecoder -from torchtune.models.llama3._tokenizer import Llama3Tokenizer from torchtune.modules.peft import LORA_ATTN_MODULES -from torchtune.modules.tokenizers import parse_hf_tokenizer_json - """ Model builders build specific instantiations using component builders. For example @@ -87,6 +84,7 @@ def lora_llama3_1_8b( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, + lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, ) -> TransformerDecoder: @@ -107,6 +105,9 @@ def lora_llama3_1_8b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for the low-rank approximation + use_dora (bool): Decompose the LoRA weight into magnitude and direction, as + introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). 
quantize_base (bool): Whether to quantize base model weights Returns: @@ -128,7 +129,7 @@ def lora_llama3_1_8b( rope_base=500000.0, lora_rank=lora_rank, lora_alpha=lora_alpha, - lora_dropout=0.05, + lora_dropout=lora_dropout, use_dora=use_dora, quantize_base=quantize_base, ) @@ -140,6 +141,7 @@ def lora_llama3_1_70b( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, + lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, ) -> TransformerDecoder: @@ -160,6 +162,9 @@ def lora_llama3_1_70b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for the low-rank approximation + use_dora (bool): Decompose the LoRA weight into magnitude and direction, as + introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). quantize_base (bool): Whether to quantize base model weights Returns: @@ -181,7 +186,7 @@ def lora_llama3_1_70b( rope_base=500000.0, lora_rank=lora_rank, lora_alpha=lora_alpha, - lora_dropout=0.05, + lora_dropout=lora_dropout, use_dora=use_dora, quantize_base=quantize_base, ) diff --git a/torchtune/models/mistral/_model_builders.py b/torchtune/models/mistral/_model_builders.py index 1442704632..6a6ef0b622 100644 --- a/torchtune/models/mistral/_model_builders.py +++ b/torchtune/models/mistral/_model_builders.py @@ -73,6 +73,7 @@ def lora_mistral_7b( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, + lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, ) -> TransformerDecoder: @@ -89,6 +90,7 @@ def lora_mistral_7b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0 use_dora (bool): Decompose the LoRA weight into magnitude and direction, as introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). quantize_base (bool): Whether to quantize base model weights @@ -112,7 +114,7 @@ def lora_mistral_7b( rope_base=10_000, lora_rank=lora_rank, lora_alpha=lora_alpha, - lora_dropout=0.05, + lora_dropout=lora_dropout, use_dora=use_dora, quantize_base=quantize_base, ) @@ -157,6 +159,7 @@ def lora_mistral_reward_7b( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, + lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, ) -> TransformerDecoder: @@ -173,6 +176,7 @@ def lora_mistral_reward_7b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0 use_dora (bool): Decompose the LoRA weight into magnitude and direction, as introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). 
quantize_base (bool): Whether to quantize base model weights @@ -197,7 +201,7 @@ def lora_mistral_reward_7b( rope_base=10_000, lora_rank=lora_rank, lora_alpha=lora_alpha, - lora_dropout=0.05, + lora_dropout=lora_dropout, use_dora=use_dora, quantize_base=quantize_base, ) diff --git a/torchtune/models/phi3/_model_builders.py b/torchtune/models/phi3/_model_builders.py index c8a0cc540b..6b3d0cf6c4 100644 --- a/torchtune/models/phi3/_model_builders.py +++ b/torchtune/models/phi3/_model_builders.py @@ -77,6 +77,7 @@ def lora_phi3_mini( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, + lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, ) -> TransformerDecoder: @@ -97,6 +98,7 @@ def lora_phi3_mini( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0 use_dora (bool): Decompose the LoRA weight into magnitude and direction, as introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). quantize_base (bool): Whether to quantize base model weights @@ -119,7 +121,7 @@ def lora_phi3_mini( norm_eps=1e-5, lora_rank=lora_rank, lora_alpha=lora_alpha, - lora_dropout=0.05, + lora_dropout=lora_dropout, use_dora=use_dora, quantize_base=quantize_base, ) diff --git a/torchtune/models/qwen2/_model_builders.py b/torchtune/models/qwen2/_model_builders.py index daa1a8057e..9d4a6619c4 100644 --- a/torchtune/models/qwen2/_model_builders.py +++ b/torchtune/models/qwen2/_model_builders.py @@ -134,7 +134,7 @@ def lora_qwen2_7b( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, - lora_dropout: float = 0.05, + lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, ) -> TransformerDecoder: @@ -155,6 +155,7 @@ def lora_qwen2_7b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation + lora_droppout (float): dropout probability for the low-rank approximation. Default: 0.0 quantize_base (bool): Whether to quantize base model weights Returns: @@ -187,7 +188,7 @@ def lora_qwen2_0_5b( apply_lora_to_mlp: bool = False, lora_rank: int = 8, lora_alpha: float = 16, - lora_dropout: float = 0.05, + lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, ) -> TiedEmbeddingTransformerDecoder: @@ -206,6 +207,7 @@ def lora_qwen2_0_5b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation + lora_droppout (float): dropout probability for the low-rank approximation. Default: 0.0 quantize_base (bool): Whether to quantize base model weights Returns: @@ -243,7 +245,7 @@ def lora_qwen2_1_5b( apply_lora_to_mlp: bool = False, lora_rank: int = 8, lora_alpha: float = 16, - lora_dropout: float = 0.05, + lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, ) -> TiedEmbeddingTransformerDecoder: @@ -262,6 +264,7 @@ def lora_qwen2_1_5b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation + lora_droppout (float): dropout probability for the low-rank approximation. 
Default: 0.0 quantize_base (bool): Whether to quantize base model weights Returns: From b0154b941c2282f9a34b31109e72842df39d72eb Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 4 Sep 2024 12:51:24 -0700 Subject: [PATCH 03/12] add lora dropout to configs --- recipes/configs/code_llama2/7B_lora_single_device.yaml | 1 + recipes/configs/code_llama2/7B_qlora_single_device.yaml | 1 + recipes/configs/dev/llama2/13B_lora_fsdp2.yaml | 1 + recipes/configs/dev/llama2/70B_lora_fsdp2.yaml | 1 + recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml | 1 + recipes/configs/dev/llama2/7B_lora_fsdp2.yaml | 1 + recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml | 1 + recipes/configs/gemma/2B_lora.yaml | 1 + recipes/configs/gemma/2B_lora_single_device.yaml | 1 + recipes/configs/gemma/2B_qlora_single_device.yaml | 1 + recipes/configs/gemma/7B_lora.yaml | 1 + recipes/configs/gemma/7B_lora_single_device.yaml | 1 + recipes/configs/gemma/7B_qlora_single_device.yaml | 1 + recipes/configs/llama2/13B_lora.yaml | 1 + recipes/configs/llama2/13B_qlora_single_device.yaml | 1 + recipes/configs/llama2/70B_lora.yaml | 1 + recipes/configs/llama2/7B_lora.yaml | 1 + recipes/configs/llama2/7B_lora_dpo.yaml | 1 + recipes/configs/llama2/7B_lora_dpo_single_device.yaml | 1 + recipes/configs/llama2/7B_lora_single_device.yaml | 1 + recipes/configs/llama2/7B_qlora_single_device.yaml | 1 + recipes/configs/llama3/70B_lora.yaml | 1 + recipes/configs/llama3/8B_lora.yaml | 1 + recipes/configs/llama3/8B_lora_single_device.yaml | 1 + recipes/configs/llama3/8B_qlora_single_device.yaml | 1 + recipes/configs/llama3_1/70B_lora.yaml | 1 + recipes/configs/llama3_1/8B_lora.yaml | 1 + recipes/configs/llama3_1/8B_lora_single_device.yaml | 1 + recipes/configs/llama3_1/8B_qlora_single_device.yaml | 1 + recipes/configs/mistral/7B_lora.yaml | 1 + recipes/configs/mistral/7B_lora_single_device.yaml | 1 + recipes/configs/mistral/7B_qlora_single_device.yaml | 1 + recipes/configs/phi3/mini_lora.yaml | 1 + recipes/configs/phi3/mini_lora_single_device.yaml | 1 + recipes/configs/phi3/mini_qlora_single_device.yaml | 1 + recipes/configs/qwen2/0.5B_lora.yaml | 1 + recipes/configs/qwen2/0.5B_lora_single_device.yaml | 1 + recipes/configs/qwen2/1.5B_lora.yaml | 1 + recipes/configs/qwen2/1.5B_lora_single_device.yaml | 1 + recipes/configs/qwen2/7B_lora.yaml | 1 + recipes/configs/qwen2/7B_lora_single_device.yaml | 1 + 41 files changed, 41 insertions(+) diff --git a/recipes/configs/code_llama2/7B_lora_single_device.yaml b/recipes/configs/code_llama2/7B_lora_single_device.yaml index 854010aee5..ba5d43113d 100644 --- a/recipes/configs/code_llama2/7B_lora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_lora_single_device.yaml @@ -23,6 +23,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 # Tokenizer tokenizer: diff --git a/recipes/configs/code_llama2/7B_qlora_single_device.yaml b/recipes/configs/code_llama2/7B_qlora_single_device.yaml index 7c635d9ff0..c2e990bf7b 100644 --- a/recipes/configs/code_llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_qlora_single_device.yaml @@ -23,6 +23,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 # Tokenizer tokenizer: diff --git a/recipes/configs/dev/llama2/13B_lora_fsdp2.yaml b/recipes/configs/dev/llama2/13B_lora_fsdp2.yaml index 93991feab1..9b5e411850 100644 --- a/recipes/configs/dev/llama2/13B_lora_fsdp2.yaml +++ b/recipes/configs/dev/llama2/13B_lora_fsdp2.yaml @@ -31,6 +31,7 @@ model: apply_lora_to_output: True lora_rank: 8 
lora_alpha: 16 + lora_dropout: 0.0 checkpointer: _component_: torchtune.training.FullModelHFCheckpointer diff --git a/recipes/configs/dev/llama2/70B_lora_fsdp2.yaml b/recipes/configs/dev/llama2/70B_lora_fsdp2.yaml index 7b2bb79dd5..75267f398a 100644 --- a/recipes/configs/dev/llama2/70B_lora_fsdp2.yaml +++ b/recipes/configs/dev/llama2/70B_lora_fsdp2.yaml @@ -21,6 +21,7 @@ model: apply_lora_to_output: False lora_rank: 16 lora_alpha: 32 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.llama2.llama2_tokenizer diff --git a/recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml b/recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml index 293327376d..9fcfbd8b19 100644 --- a/recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml +++ b/recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml @@ -21,6 +21,7 @@ model: apply_lora_to_output: False lora_rank: 16 lora_alpha: 32 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.llama2.llama2_tokenizer diff --git a/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml b/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml index 32a24bd614..fd5eda4b5e 100644 --- a/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml +++ b/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml @@ -30,6 +30,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.llama2.llama2_tokenizer diff --git a/recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml b/recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml index d259d98b61..905a8e0acc 100644 --- a/recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml +++ b/recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml @@ -29,6 +29,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.llama2.llama2_tokenizer diff --git a/recipes/configs/gemma/2B_lora.yaml b/recipes/configs/gemma/2B_lora.yaml index 52dce07b39..2e345791da 100644 --- a/recipes/configs/gemma/2B_lora.yaml +++ b/recipes/configs/gemma/2B_lora.yaml @@ -34,6 +34,7 @@ model: apply_lora_to_mlp: True lora_rank: 64 lora_alpha: 16 + lora_dropout: 0.0 checkpointer: _component_: torchtune.training.FullModelHFCheckpointer diff --git a/recipes/configs/gemma/2B_lora_single_device.yaml b/recipes/configs/gemma/2B_lora_single_device.yaml index 74ff398ab3..5b0af37ffa 100644 --- a/recipes/configs/gemma/2B_lora_single_device.yaml +++ b/recipes/configs/gemma/2B_lora_single_device.yaml @@ -33,6 +33,7 @@ model: apply_lora_to_mlp: True lora_rank: 64 lora_alpha: 16 + lora_dropout: 0.0 checkpointer: _component_: torchtune.training.FullModelHFCheckpointer diff --git a/recipes/configs/gemma/2B_qlora_single_device.yaml b/recipes/configs/gemma/2B_qlora_single_device.yaml index 7099112988..d67d79bfaa 100644 --- a/recipes/configs/gemma/2B_qlora_single_device.yaml +++ b/recipes/configs/gemma/2B_qlora_single_device.yaml @@ -33,6 +33,7 @@ model: apply_lora_to_mlp: True lora_rank: 64 lora_alpha: 16 + lora_dropout: 0.0 checkpointer: _component_: torchtune.training.FullModelHFCheckpointer diff --git a/recipes/configs/gemma/7B_lora.yaml b/recipes/configs/gemma/7B_lora.yaml index d1469dd78a..b79bbb6845 100644 --- a/recipes/configs/gemma/7B_lora.yaml +++ b/recipes/configs/gemma/7B_lora.yaml @@ -34,6 +34,7 @@ model: apply_lora_to_mlp: True lora_rank: 64 lora_alpha: 16 + lora_dropout: 0.0 checkpointer: _component_: torchtune.training.FullModelHFCheckpointer diff --git a/recipes/configs/gemma/7B_lora_single_device.yaml b/recipes/configs/gemma/7B_lora_single_device.yaml index 0a25bae846..a8dfc1878e 100644 --- 
a/recipes/configs/gemma/7B_lora_single_device.yaml +++ b/recipes/configs/gemma/7B_lora_single_device.yaml @@ -33,6 +33,7 @@ model: apply_lora_to_mlp: True lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 checkpointer: _component_: torchtune.training.FullModelHFCheckpointer diff --git a/recipes/configs/gemma/7B_qlora_single_device.yaml b/recipes/configs/gemma/7B_qlora_single_device.yaml index 3d63a548d1..acb044ece9 100644 --- a/recipes/configs/gemma/7B_qlora_single_device.yaml +++ b/recipes/configs/gemma/7B_qlora_single_device.yaml @@ -33,6 +33,7 @@ model: apply_lora_to_mlp: True lora_rank: 64 lora_alpha: 16 + lora_dropout: 0.0 checkpointer: _component_: torchtune.training.FullModelHFCheckpointer diff --git a/recipes/configs/llama2/13B_lora.yaml b/recipes/configs/llama2/13B_lora.yaml index f3534026de..33d160ece0 100644 --- a/recipes/configs/llama2/13B_lora.yaml +++ b/recipes/configs/llama2/13B_lora.yaml @@ -27,6 +27,7 @@ model: apply_lora_to_output: True lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 checkpointer: _component_: torchtune.training.FullModelHFCheckpointer diff --git a/recipes/configs/llama2/13B_qlora_single_device.yaml b/recipes/configs/llama2/13B_qlora_single_device.yaml index 5f0995c911..74c51fd6bc 100644 --- a/recipes/configs/llama2/13B_qlora_single_device.yaml +++ b/recipes/configs/llama2/13B_qlora_single_device.yaml @@ -23,6 +23,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.llama2.llama2_tokenizer diff --git a/recipes/configs/llama2/70B_lora.yaml b/recipes/configs/llama2/70B_lora.yaml index 333b5e4615..2c745402f1 100644 --- a/recipes/configs/llama2/70B_lora.yaml +++ b/recipes/configs/llama2/70B_lora.yaml @@ -17,6 +17,7 @@ model: apply_lora_to_output: False lora_rank: 16 lora_alpha: 32 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.llama2.llama2_tokenizer diff --git a/recipes/configs/llama2/7B_lora.yaml b/recipes/configs/llama2/7B_lora.yaml index 6c8e409662..dd527194b3 100644 --- a/recipes/configs/llama2/7B_lora.yaml +++ b/recipes/configs/llama2/7B_lora.yaml @@ -26,6 +26,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.llama2.llama2_tokenizer diff --git a/recipes/configs/llama2/7B_lora_dpo.yaml b/recipes/configs/llama2/7B_lora_dpo.yaml index 2f89a96150..81a83577fe 100644 --- a/recipes/configs/llama2/7B_lora_dpo.yaml +++ b/recipes/configs/llama2/7B_lora_dpo.yaml @@ -25,6 +25,7 @@ model: lora_rank: 8 lora_alpha: 16 lora_dropout: 0.0 + lora_dropout: 0.0 # Tokenizer tokenizer: diff --git a/recipes/configs/llama2/7B_lora_dpo_single_device.yaml b/recipes/configs/llama2/7B_lora_dpo_single_device.yaml index f85fe9f807..ef70976ee4 100644 --- a/recipes/configs/llama2/7B_lora_dpo_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_dpo_single_device.yaml @@ -24,6 +24,7 @@ model: lora_rank: 8 lora_alpha: 16 lora_dropout: 0.0 + lora_dropout: 0.0 # Tokenizer tokenizer: diff --git a/recipes/configs/llama2/7B_lora_single_device.yaml b/recipes/configs/llama2/7B_lora_single_device.yaml index 3ace543928..dba911a0d7 100644 --- a/recipes/configs/llama2/7B_lora_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_single_device.yaml @@ -24,6 +24,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.llama2.llama2_tokenizer diff --git a/recipes/configs/llama2/7B_qlora_single_device.yaml b/recipes/configs/llama2/7B_qlora_single_device.yaml index 
bf710df03c..427cdc50be 100644 --- a/recipes/configs/llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/llama2/7B_qlora_single_device.yaml @@ -23,6 +23,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.llama2.llama2_tokenizer diff --git a/recipes/configs/llama3/70B_lora.yaml b/recipes/configs/llama3/70B_lora.yaml index 8f3979e086..c323a277ed 100644 --- a/recipes/configs/llama3/70B_lora.yaml +++ b/recipes/configs/llama3/70B_lora.yaml @@ -17,6 +17,7 @@ model: apply_lora_to_output: False lora_rank: 16 lora_alpha: 32 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer diff --git a/recipes/configs/llama3/8B_lora.yaml b/recipes/configs/llama3/8B_lora.yaml index 08ac52ea00..26016cd5cd 100644 --- a/recipes/configs/llama3/8B_lora.yaml +++ b/recipes/configs/llama3/8B_lora.yaml @@ -31,6 +31,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 checkpointer: _component_: torchtune.training.FullModelMetaCheckpointer diff --git a/recipes/configs/llama3/8B_lora_single_device.yaml b/recipes/configs/llama3/8B_lora_single_device.yaml index 7a7c64a4a0..c35d93b9b4 100644 --- a/recipes/configs/llama3/8B_lora_single_device.yaml +++ b/recipes/configs/llama3/8B_lora_single_device.yaml @@ -24,6 +24,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 # Tokenizer tokenizer: diff --git a/recipes/configs/llama3/8B_qlora_single_device.yaml b/recipes/configs/llama3/8B_qlora_single_device.yaml index b6524839de..e317ae5f1d 100644 --- a/recipes/configs/llama3/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3/8B_qlora_single_device.yaml @@ -23,6 +23,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 # Tokenizer tokenizer: diff --git a/recipes/configs/llama3_1/70B_lora.yaml b/recipes/configs/llama3_1/70B_lora.yaml index 65ebaddabf..995039db36 100644 --- a/recipes/configs/llama3_1/70B_lora.yaml +++ b/recipes/configs/llama3_1/70B_lora.yaml @@ -16,6 +16,7 @@ model: apply_lora_to_output: False lora_rank: 16 lora_alpha: 32 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer diff --git a/recipes/configs/llama3_1/8B_lora.yaml b/recipes/configs/llama3_1/8B_lora.yaml index 1e9649bcee..7f364a4462 100644 --- a/recipes/configs/llama3_1/8B_lora.yaml +++ b/recipes/configs/llama3_1/8B_lora.yaml @@ -31,6 +31,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 checkpointer: _component_: torchtune.training.FullModelHFCheckpointer diff --git a/recipes/configs/llama3_1/8B_lora_single_device.yaml b/recipes/configs/llama3_1/8B_lora_single_device.yaml index 4121a4e51f..a445051a98 100644 --- a/recipes/configs/llama3_1/8B_lora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_lora_single_device.yaml @@ -24,6 +24,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 # Tokenizer tokenizer: diff --git a/recipes/configs/llama3_1/8B_qlora_single_device.yaml b/recipes/configs/llama3_1/8B_qlora_single_device.yaml index 7a4baf5ca7..6b8b3497c2 100644 --- a/recipes/configs/llama3_1/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_qlora_single_device.yaml @@ -23,6 +23,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 # Tokenizer tokenizer: diff --git a/recipes/configs/mistral/7B_lora.yaml b/recipes/configs/mistral/7B_lora.yaml index 1a5286761a..a15b437d5a 100644 --- 
a/recipes/configs/mistral/7B_lora.yaml +++ b/recipes/configs/mistral/7B_lora.yaml @@ -42,6 +42,7 @@ model: apply_lora_to_output: True lora_rank: 64 lora_alpha: 16 + lora_dropout: 0.0 checkpointer: _component_: torchtune.training.FullModelHFCheckpointer diff --git a/recipes/configs/mistral/7B_lora_single_device.yaml b/recipes/configs/mistral/7B_lora_single_device.yaml index f42d2d5fc3..1c461f9f46 100644 --- a/recipes/configs/mistral/7B_lora_single_device.yaml +++ b/recipes/configs/mistral/7B_lora_single_device.yaml @@ -39,6 +39,7 @@ model: apply_lora_to_output: True lora_rank: 64 lora_alpha: 16 + lora_dropout: 0.0 checkpointer: _component_: torchtune.training.FullModelHFCheckpointer diff --git a/recipes/configs/mistral/7B_qlora_single_device.yaml b/recipes/configs/mistral/7B_qlora_single_device.yaml index 29929d150f..54d0906150 100644 --- a/recipes/configs/mistral/7B_qlora_single_device.yaml +++ b/recipes/configs/mistral/7B_qlora_single_device.yaml @@ -40,6 +40,7 @@ model: apply_lora_to_output: False lora_rank: 64 lora_alpha: 16 + lora_dropout: 0.0 checkpointer: _component_: torchtune.training.FullModelHFCheckpointer diff --git a/recipes/configs/phi3/mini_lora.yaml b/recipes/configs/phi3/mini_lora.yaml index 4d8116d83f..0556b1de2c 100644 --- a/recipes/configs/phi3/mini_lora.yaml +++ b/recipes/configs/phi3/mini_lora.yaml @@ -25,6 +25,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 # Tokenizer tokenizer: diff --git a/recipes/configs/phi3/mini_lora_single_device.yaml b/recipes/configs/phi3/mini_lora_single_device.yaml index b5783d6f7e..15e2fc02a5 100644 --- a/recipes/configs/phi3/mini_lora_single_device.yaml +++ b/recipes/configs/phi3/mini_lora_single_device.yaml @@ -23,6 +23,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 # Tokenizer tokenizer: diff --git a/recipes/configs/phi3/mini_qlora_single_device.yaml b/recipes/configs/phi3/mini_qlora_single_device.yaml index b5e031745f..5b3ae9e6a1 100644 --- a/recipes/configs/phi3/mini_qlora_single_device.yaml +++ b/recipes/configs/phi3/mini_qlora_single_device.yaml @@ -23,6 +23,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 # Tokenizer tokenizer: diff --git a/recipes/configs/qwen2/0.5B_lora.yaml b/recipes/configs/qwen2/0.5B_lora.yaml index 3f67c169d8..dc0f0ebe5f 100644 --- a/recipes/configs/qwen2/0.5B_lora.yaml +++ b/recipes/configs/qwen2/0.5B_lora.yaml @@ -25,6 +25,7 @@ model: apply_lora_to_mlp: False lora_rank: 32 lora_alpha: 64 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.qwen2.qwen2_tokenizer diff --git a/recipes/configs/qwen2/0.5B_lora_single_device.yaml b/recipes/configs/qwen2/0.5B_lora_single_device.yaml index fff48bb626..23bcb1742c 100644 --- a/recipes/configs/qwen2/0.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2/0.5B_lora_single_device.yaml @@ -23,6 +23,7 @@ model: apply_lora_to_mlp: False lora_rank: 32 lora_alpha: 64 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.qwen2.qwen2_tokenizer diff --git a/recipes/configs/qwen2/1.5B_lora.yaml b/recipes/configs/qwen2/1.5B_lora.yaml index 9d88b3c19e..85e9c6577c 100644 --- a/recipes/configs/qwen2/1.5B_lora.yaml +++ b/recipes/configs/qwen2/1.5B_lora.yaml @@ -23,6 +23,7 @@ model: apply_lora_to_mlp: False lora_rank: 32 lora_alpha: 64 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.qwen2.qwen2_tokenizer diff --git a/recipes/configs/qwen2/1.5B_lora_single_device.yaml b/recipes/configs/qwen2/1.5B_lora_single_device.yaml index 
c55a485dd3..6c77f84d16 100644 --- a/recipes/configs/qwen2/1.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2/1.5B_lora_single_device.yaml @@ -23,6 +23,7 @@ model: apply_lora_to_mlp: False lora_rank: 32 lora_alpha: 64 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.qwen2.qwen2_tokenizer diff --git a/recipes/configs/qwen2/7B_lora.yaml b/recipes/configs/qwen2/7B_lora.yaml index 6431bcaf56..fa6737e8ed 100644 --- a/recipes/configs/qwen2/7B_lora.yaml +++ b/recipes/configs/qwen2/7B_lora.yaml @@ -26,6 +26,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.qwen2.qwen2_tokenizer diff --git a/recipes/configs/qwen2/7B_lora_single_device.yaml b/recipes/configs/qwen2/7B_lora_single_device.yaml index 0425850cd1..188066e1e2 100644 --- a/recipes/configs/qwen2/7B_lora_single_device.yaml +++ b/recipes/configs/qwen2/7B_lora_single_device.yaml @@ -24,6 +24,7 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.qwen2.qwen2_tokenizer From e27f7365e222cc3050f3a47042629c3f67f0593c Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 5 Sep 2024 13:17:02 -0700 Subject: [PATCH 04/12] typo --- torchtune/models/qwen2/_model_builders.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtune/models/qwen2/_model_builders.py b/torchtune/models/qwen2/_model_builders.py index 9d4a6619c4..f6b031d5a4 100644 --- a/torchtune/models/qwen2/_model_builders.py +++ b/torchtune/models/qwen2/_model_builders.py @@ -155,7 +155,7 @@ def lora_qwen2_7b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation - lora_droppout (float): dropout probability for the low-rank approximation. Default: 0.0 + lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0 quantize_base (bool): Whether to quantize base model weights Returns: @@ -207,7 +207,7 @@ def lora_qwen2_0_5b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation - lora_droppout (float): dropout probability for the low-rank approximation. Default: 0.0 + lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0 quantize_base (bool): Whether to quantize base model weights Returns: @@ -264,7 +264,7 @@ def lora_qwen2_1_5b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation - lora_droppout (float): dropout probability for the low-rank approximation. Default: 0.0 + lora_dropout (float): dropout probability for the low-rank approximation. 
Default: 0.0 quantize_base (bool): Whether to quantize base model weights Returns: From 002d67fa5478938824a194fd57842354aa4a19fa Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Sep 2024 14:20:13 -0700 Subject: [PATCH 05/12] add missing lora dropout --- torchtune/models/llama3/_model_builders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtune/models/llama3/_model_builders.py b/torchtune/models/llama3/_model_builders.py index 6fbb718b2d..282c638b2f 100644 --- a/torchtune/models/llama3/_model_builders.py +++ b/torchtune/models/llama3/_model_builders.py @@ -95,6 +95,7 @@ def lora_llama3_8b( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, + lora_dropout: float = 0.0, quantize_base: bool = False, use_dora: bool = False, ) -> TransformerDecoder: From 6adf19f9913e37e93ad527ae8a0f20df2ec9609b Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Sep 2024 16:34:08 -0700 Subject: [PATCH 06/12] update qwen --- docs/source/api_ref_modules.rst | 1 - recipes/configs/qwen2/1.5B_lora.yaml | 12 +-- torchtune/models/qwen2/_component_builders.py | 77 +++++++------------ torchtune/modules/transformer.py | 4 + 4 files changed, 39 insertions(+), 55 deletions(-) diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst index a95d9306f0..f4dd699cca 100644 --- a/docs/source/api_ref_modules.rst +++ b/docs/source/api_ref_modules.rst @@ -22,7 +22,6 @@ Modeling Components and Building Blocks TransformerSelfAttentionLayer TransformerCrossAttentionLayer TransformerDecoder - TiedEmbeddingTransformerDecoder VisionTransformer Base Tokenizers diff --git a/recipes/configs/qwen2/1.5B_lora.yaml b/recipes/configs/qwen2/1.5B_lora.yaml index 85e9c6577c..35516eefff 100644 --- a/recipes/configs/qwen2/1.5B_lora.yaml +++ b/recipes/configs/qwen2/1.5B_lora.yaml @@ -1,5 +1,5 @@ # Config for multi-device LoRA finetuning in lora_finetune_distributed.py -# using a Qwen2 0.5B model +# using a Qwen2 1.5B model # # This config assumes that you've run the following command before launching # this run: @@ -27,18 +27,18 @@ model: tokenizer: _component_: torchtune.models.qwen2.qwen2_tokenizer - path: /tmp/Qwen2-0.5B-Instruct/vocab.json - merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt + path: /tmp/Qwen2-1.5B-Instruct/vocab.json + merges_file: /tmp/Qwen2-1.5B-Instruct/merges.txt max_seq_len: null checkpointer: _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/Qwen2-0.5B-Instruct + checkpoint_dir: /tmp/Qwen2-1.5B-Instruct checkpoint_files: [ model.safetensors ] recipe_checkpoint: null - output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune + output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune model_type: QWEN2 resume_from_checkpoint: False @@ -67,7 +67,7 @@ max_steps_per_epoch: null gradient_accumulation_steps: 8 # Logging -output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune +output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune metric_logger: _component_: torchtune.training.metric_logging.DiskLogger log_dir: ${output_dir} diff --git a/torchtune/models/qwen2/_component_builders.py b/torchtune/models/qwen2/_component_builders.py index 734aae1423..048285ce66 100644 --- a/torchtune/models/qwen2/_component_builders.py +++ b/torchtune/models/qwen2/_component_builders.py @@ -5,12 +5,12 @@ # LICENSE file in the root directory of this source tree. 
from functools import partial -from typing import List, Union +from typing import List from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook from torch import nn - -from torchtune.modules.transformer import TransformerDecoder, TiedEmbeddingTransformerDecoder +import torch.functional as F +from torchtune.modules.transformer import TransformerDecoder from torchtune.models.qwen2._positional_embeddings import Qwen2RotaryPositionalEmbeddings from torchtune.modules import ( @@ -48,7 +48,7 @@ def qwen2( norm_eps: float = 1e-5, rope_base: float = 1_000_000.0, tie_word_embeddings: bool = False, -) -> Union[TransformerDecoder, TiedEmbeddingTransformerDecoder]: +) -> TransformerDecoder: """ Build the decoder associated with the Qwen2 model. This includes: - Token embeddings @@ -104,28 +104,20 @@ def qwen2( mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), ) tok_embeddings = nn.Embedding(vocab_size, embed_dim) - output_proj = None if tie_word_embeddings else nn.Linear(embed_dim, vocab_size, bias=False) - if output_proj is None: - return TiedEmbeddingTransformerDecoder( - tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, - max_seq_len=max_seq_len, - num_heads=num_heads, - head_dim=head_dim, - norm=RMSNorm(embed_dim, eps=norm_eps), - ) + if tie_word_embeddings: + output_proj = lambda x: F.linear(x, tok_embeddings.weight) else: - return TransformerDecoder( - tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, - max_seq_len=max_seq_len, - num_heads=num_heads, - head_dim=head_dim, - norm=RMSNorm(embed_dim, eps=norm_eps), - output=output_proj, - ) + output_proj = nn.Linear(embed_dim, vocab_size, bias=False) + return TransformerDecoder( + tok_embeddings=tok_embeddings, + layers=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) def qwen2_mlp(dim: int, hidden_dim: int) -> FeedForward: @@ -162,7 +154,7 @@ def lora_qwen2( use_dora: bool = False, # Quantization args quantize_base: bool = False, -) -> Union[TransformerDecoder, TiedEmbeddingTransformerDecoder]: +) -> TransformerDecoder: """ Return a version of Qwen2 (an instance of :func:`~torchtune.models.qwen2.transformer.Qwen2TransformerDecoder`) with LoRA applied based on the passed in configuration. @@ -251,7 +243,7 @@ def lora_qwen2( "apply_lora_to_output is incompatible with tie_word_embeddings," " as there would be no output to apply lora to!" ) - output_proj = None + output_proj = lambda x: F.linear(x, tok_embeddings.weight) else: # TODO: quantize_base is not applied to final output_proj currently. 
adapter_cls = DoRALinear if use_dora else LoRALinear @@ -260,27 +252,16 @@ def lora_qwen2( if apply_lora_to_output else nn.Linear(embed_dim, vocab_size, bias=False) ) - if output_proj is None: - model = TiedEmbeddingTransformerDecoder( - tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, - max_seq_len=max_seq_len, - num_heads=num_heads, - head_dim=(embed_dim // num_heads), - norm=RMSNorm(embed_dim, eps=norm_eps), - ) - else: - model = TransformerDecoder( - tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, - max_seq_len=max_seq_len, - num_heads=num_heads, - head_dim=(embed_dim // num_heads), - norm=RMSNorm(embed_dim, eps=norm_eps), - output=output_proj, - ) + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layers=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=(embed_dim // num_heads), + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) if quantize_base: # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 65e511e5ea..feb7caa55f 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -516,6 +516,10 @@ def forward( return output +@deprecated( + msg="Please use torchtune.modules.TransformerDecoder instead. \ +If you need an example, see torchtune.models.qwen2._component_builders.py" +) class TiedEmbeddingTransformerDecoder(nn.Module): """ Transformer Decoder with tied embedding weight. A key difference between From c6dd2981b393fd22a9c0368ada8b1aecb97e3712 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Sep 2024 16:37:10 -0700 Subject: [PATCH 07/12] change typehint --- torchtune/models/qwen2/_model_builders.py | 26 +++++++++++------------ 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/torchtune/models/qwen2/_model_builders.py b/torchtune/models/qwen2/_model_builders.py index 1ec9cc53e5..e8e7334d57 100644 --- a/torchtune/models/qwen2/_model_builders.py +++ b/torchtune/models/qwen2/_model_builders.py @@ -7,7 +7,7 @@ from torchtune.models.qwen2._component_builders import qwen2, lora_qwen2 from torchtune.models.qwen2._tokenizer import Qwen2Tokenizer -from torchtune.modules import TransformerDecoder, TiedEmbeddingTransformerDecoder +from torchtune.modules import TransformerDecoder from torchtune.modules.peft import LORA_ATTN_MODULES from torchtune.modules.tokenizers import parse_hf_tokenizer_json from torchtune.data._prompt_templates import _TemplateType @@ -42,17 +42,17 @@ def qwen2_7b() -> TransformerDecoder: ) -def qwen2_0_5b() -> TiedEmbeddingTransformerDecoder: +def qwen2_0_5b() -> TransformerDecoder: """ Builder for creating a Qwen2 model initialized w/ the default 0.5B parameter values from https://huggingface.co/Qwen/Qwen2-0.5B-Instruct Returns: - TiedEmbeddingTransformerDecoder: Instantiation of Qwen2 0.5B model + TransformerDecoder: Instantiation of Qwen2 0.5B model Note: Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default - and returns an instance of `TiedEmbeddingTransformerDecoder`. + and returns an instance of `TransformerDecoder`. 
""" return qwen2( vocab_size=151936, @@ -69,17 +69,17 @@ def qwen2_0_5b() -> TiedEmbeddingTransformerDecoder: ) -def qwen2_1_5b() -> TiedEmbeddingTransformerDecoder: +def qwen2_1_5b() -> TransformerDecoder: """ Builder for creating a Qwen2 model initialized w/ the default 1.5B parameter values from https://huggingface.co/Qwen/Qwen2-1.5B-Instruct Returns: - TiedEmbeddingTransformerDecoder: Instantiation of Qwen2 1.5B model + TransformerDecoder: Instantiation of Qwen2 1.5B model Note: Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default - and returns an instance of `TiedEmbeddingTransformerDecoder`. + and returns an instance of `TransformerDecoder`. """ return qwen2( vocab_size=151936, @@ -191,7 +191,7 @@ def lora_qwen2_0_5b( lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, -) -> TiedEmbeddingTransformerDecoder: +) -> TransformerDecoder: """ Builder for creating a Qwen2 0.5B model with LoRA enabled. @@ -211,11 +211,11 @@ def lora_qwen2_0_5b( quantize_base (bool): Whether to quantize base model weights Returns: - TiedEmbeddingTransformerDecoder: Instantiation of Qwen2 0.5B model with LoRA applied + TransformerDecoder: Instantiation of Qwen2 0.5B model with LoRA applied Note: Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default - and returns an instance of `TiedEmbeddingTransformerDecoder`. + and returns an instance of `TransformerDecoder`. """ return lora_qwen2( lora_attn_modules=lora_attn_modules, @@ -248,7 +248,7 @@ def lora_qwen2_1_5b( lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, -) -> TiedEmbeddingTransformerDecoder: +) -> TransformerDecoder: """ Builder for creating a Qwen2 1.5B model with LoRA enabled. @@ -268,11 +268,11 @@ def lora_qwen2_1_5b( quantize_base (bool): Whether to quantize base model weights Returns: - TiedEmbeddingTransformerDecoder: Instantiation of Qwen2 1.5B model with LoRA applied + TransformerDecoder: Instantiation of Qwen2 1.5B model with LoRA applied Note: Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default - and returns an instance of `TiedEmbeddingTransformerDecoder`. + and returns an instance of `TransformerDecoder`. 
""" return lora_qwen2( lora_attn_modules=lora_attn_modules, From a55f9ae6d9b4fdb6ae3e09ee27a84aa37c6c80ac Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Sep 2024 16:41:17 -0700 Subject: [PATCH 08/12] import deprecated --- torchtune/modules/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index feb7caa55f..55dc27f24c 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -10,8 +10,8 @@ import torch.nn.functional as F from torch import nn from torchtune.modules import MultiHeadAttention - from torchtune.modules.attention_utils import _MaskType +from torchtune.utils.logging import deprecated class TransformerSelfAttentionLayer(nn.Module): From b427bf5f7033dccf07cf0f2614faf898f9538ffe Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Sep 2024 16:43:11 -0700 Subject: [PATCH 09/12] update import --- torchtune/models/qwen2/_component_builders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/models/qwen2/_component_builders.py b/torchtune/models/qwen2/_component_builders.py index 048285ce66..fe4aa0fcf9 100644 --- a/torchtune/models/qwen2/_component_builders.py +++ b/torchtune/models/qwen2/_component_builders.py @@ -9,7 +9,7 @@ from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook from torch import nn -import torch.functional as F +import torch.nn.functional as F from torchtune.modules.transformer import TransformerDecoder from torchtune.models.qwen2._positional_embeddings import Qwen2RotaryPositionalEmbeddings From f54904e1aecac48b4dc1bab5eab8b4b3f13b37e5 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Sep 2024 19:15:37 -0700 Subject: [PATCH 10/12] add tied linear --- docs/source/api_ref_modules.rst | 1 + torchtune/models/qwen2/_component_builders.py | 5 +-- torchtune/modules/__init__.py | 3 +- torchtune/modules/tied_linear.py | 34 +++++++++++++++++++ torchtune/modules/transformer.py | 6 ++-- 5 files changed, 43 insertions(+), 6 deletions(-) create mode 100644 torchtune/modules/tied_linear.py diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst index f4dd699cca..2c0b0fc735 100644 --- a/docs/source/api_ref_modules.rst +++ b/docs/source/api_ref_modules.rst @@ -19,6 +19,7 @@ Modeling Components and Building Blocks RMSNorm Fp32LayerNorm TanhGate + TiedLinear TransformerSelfAttentionLayer TransformerCrossAttentionLayer TransformerDecoder diff --git a/torchtune/models/qwen2/_component_builders.py b/torchtune/models/qwen2/_component_builders.py index fe4aa0fcf9..70eec1b4d7 100644 --- a/torchtune/models/qwen2/_component_builders.py +++ b/torchtune/models/qwen2/_component_builders.py @@ -18,6 +18,7 @@ FeedForward, RMSNorm, TransformerSelfAttentionLayer, + TiedLinear ) @@ -105,7 +106,7 @@ def qwen2( ) tok_embeddings = nn.Embedding(vocab_size, embed_dim) if tie_word_embeddings: - output_proj = lambda x: F.linear(x, tok_embeddings.weight) + output_proj = TiedLinear(tok_embeddings) else: output_proj = nn.Linear(embed_dim, vocab_size, bias=False) return TransformerDecoder( @@ -243,7 +244,7 @@ def lora_qwen2( "apply_lora_to_output is incompatible with tie_word_embeddings," " as there would be no output to apply lora to!" ) - output_proj = lambda x: F.linear(x, tok_embeddings.weight) + output_proj = TiedLinear(tok_embeddings) else: # TODO: quantize_base is not applied to final output_proj currently. 
adapter_cls = DoRALinear if use_dora else LoRALinear diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index 695a7c6f1f..7e5f0cf5d9 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -15,6 +15,7 @@ from .position_embeddings import RotaryPositionalEmbeddings # noqa from .rms_norm import RMSNorm # noqa from .tanh_gate import TanhGate # noqa +from .tied_linear import TiedLinear # noqa from .transformer import ( # noqa TiedEmbeddingTransformerDecoder, TransformerCrossAttentionLayer, @@ -32,7 +33,7 @@ "KVCache", "RotaryPositionalEmbeddings", "RMSNorm", - "Fp32LayerNorm", + "TiedLinear" "Fp32LayerNorm", "VisionTransformer", "TransformerDecoder", "TiedEmbeddingTransformerDecoder", diff --git a/torchtune/modules/tied_linear.py b/torchtune/modules/tied_linear.py new file mode 100644 index 0000000000..6864f6fa1b --- /dev/null +++ b/torchtune/modules/tied_linear.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TiedLinear: + """ + A tied linear layer, without bias, that shares the same weight as another linear layer. + This is useful for models that use tied weights, such as qwen and gemma. + It requires as input an nn.Module, instead of the weight of the module, so it + can work with FSDP. Otherwise, the memory reference will be lost after FSDP is applied. + + Args: + tied_module (nn.Module): The module whose weight is shared. Only + the weight is used. The bias is ignored. + Raises: + AttributeError: If the provided module does not have an attribute 'weight'. + """ + + def __init__(self, tied_module: nn.Module): + self.tied_module = tied_module + if not hasattr(tied_module, "weight"): + raise AttributeError( + "Provided module does not have attribute 'weight'. Please check your tied_module." + ) + + def __call__(self, x: torch.tensor) -> torch.tensor: + return F.linear(x, self.tied_module.weight) diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 55dc27f24c..6c0e65b811 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import copy -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import torch import torch.nn.functional as F @@ -295,7 +295,7 @@ class TransformerDecoder(nn.Module): to setup the :func:`~torchtune.modules.KVCache` norm (nn.Module): Callable that applies normalization to the output of the decoder, before final MLP. - output (nn.Linear): Callable that applies a linear transformation to the output of + output (Union[nn.Linear, Callable]): Callable that applies a linear transformation to the output of the decoder. num_layers (Optional[int]): Number of Transformer Decoder layers, only define when layers is not a list. 
@@ -320,7 +320,7 @@ def __init__( num_heads: int, head_dim: int, norm: nn.Module, - output: nn.Linear, + output: Union[nn.Linear, Callable], num_layers: Optional[int] = None, output_hidden_states: Optional[List[int]] = None, ) -> None: From a0bd26b2bab2821c2b90892ecb352df8022fbfbd Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Sep 2024 19:22:24 -0700 Subject: [PATCH 11/12] remove unused import --- torchtune/models/qwen2/_component_builders.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtune/models/qwen2/_component_builders.py b/torchtune/models/qwen2/_component_builders.py index 70eec1b4d7..716fe337ad 100644 --- a/torchtune/models/qwen2/_component_builders.py +++ b/torchtune/models/qwen2/_component_builders.py @@ -9,7 +9,6 @@ from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook from torch import nn -import torch.nn.functional as F from torchtune.modules.transformer import TransformerDecoder from torchtune.models.qwen2._positional_embeddings import Qwen2RotaryPositionalEmbeddings From b26b4fc80382d5f9ff16ff1b7dd93068904d3c9f Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 12 Sep 2024 10:40:58 -0700 Subject: [PATCH 12/12] fix comments --- torchtune/modules/__init__.py | 3 ++- torchtune/modules/transformer.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index 7e5f0cf5d9..66076d52f9 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -33,7 +33,8 @@ "KVCache", "RotaryPositionalEmbeddings", "RMSNorm", - "TiedLinear" "Fp32LayerNorm", + "TiedLinear", + "Fp32LayerNorm", "VisionTransformer", "TransformerDecoder", "TiedEmbeddingTransformerDecoder", diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 6c0e65b811..714924c69d 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -518,7 +518,8 @@ def forward( @deprecated( msg="Please use torchtune.modules.TransformerDecoder instead. \ -If you need an example, see torchtune.models.qwen2._component_builders.py" +If you need an example, see torchtune.models.qwen2._component_builders.py \ +and how to implement torch.modules.TiedLinear for the output projection." ) class TiedEmbeddingTransformerDecoder(nn.Module): """
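
Taken together, the series makes two behavioral changes: LoRA/DoRA layers build an nn.Identity instead of an nn.Dropout when lora_dropout == 0.0, and tied-weight models such as Qwen2 0.5B/1.5B pass a TiedLinear callable as the output projection of the plain TransformerDecoder rather than using the now-deprecated TiedEmbeddingTransformerDecoder. The standalone sketch below illustrates both ideas with nothing but torch installed. TiedLinear here mirrors the new torchtune/modules/tied_linear.py added in PATCH 10; ToyLoRALinear is a hypothetical toy module, not torchtune's LoRALinear, and is included only to show the dropout guard.

import torch
import torch.nn as nn
import torch.nn.functional as F


class TiedLinear:
    # Callable that reuses an embedding module's weight as the output projection,
    # so the projection stays tied to the (possibly FSDP-wrapped) embedding module.
    def __init__(self, tied_module: nn.Module):
        if not hasattr(tied_module, "weight"):
            raise AttributeError("tied_module must expose a 'weight' attribute")
        self.tied_module = tied_module

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # F.linear expects weight of shape [out_features, in_features],
        # which is exactly the embedding table's shape [vocab_size, embed_dim].
        return F.linear(x, self.tied_module.weight)


class ToyLoRALinear(nn.Module):
    # Minimal LoRA-style adapter: the dropout path becomes a no-op Identity
    # when p == 0.0, matching the change to torchtune/modules/peft/{lora,dora}.py.
    def __init__(self, in_dim: int, out_dim: int, rank: int = 8,
                 alpha: float = 16.0, dropout: float = 0.0):
        super().__init__()
        self.scaling = alpha / rank
        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lora_b(self.lora_a(self.dropout(x))) * self.scaling


if __name__ == "__main__":
    emb = nn.Embedding(num_embeddings=100, embedding_dim=16)
    output_proj = TiedLinear(emb)                 # shares emb.weight, adds no parameters
    tokens = torch.randint(0, 100, (2, 5))
    hidden = emb(tokens)                          # [2, 5, 16]
    logits = output_proj(hidden)                  # [2, 5, 100], computed with the tied weight
    adapter = ToyLoRALinear(16, 16, dropout=0.0)  # dropout path is nn.Identity here
    print(logits.shape, adapter(hidden).shape)

Note the design choice carried over from the patch: TiedLinear holds a reference to the embedding module itself rather than to its weight tensor, so the tie survives FSDP wrapping, and it is a plain callable rather than an nn.Module, so the shared weight is not registered a second time in the decoder's state dict.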