From c4a12c05db302715340403ef1e16b7e691675d7c Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Mon, 30 Sep 2024 17:51:47 -0700 Subject: [PATCH 01/11] [wip] QLoRA with bias + Llama 3.2 Vision QLoRA configs --- recipes/configs/llama3_2_vision/11B_lora.yaml | 2 +- .../11B_lora_single_device.yaml | 2 +- .../configs/llama3_2_vision/11B_qlora.yaml | 88 +++++++++ .../11B_qlora_single_device.yaml | 112 +++++++++++ torchtune/_recipe_registry.py | 8 + torchtune/models/clip/_component_builders.py | 175 +++++++++++------- torchtune/modules/low_precision/nf4_linear.py | 14 +- torchtune/modules/peft/lora.py | 6 +- 8 files changed, 326 insertions(+), 81 deletions(-) create mode 100644 recipes/configs/llama3_2_vision/11B_qlora.yaml create mode 100644 recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml diff --git a/recipes/configs/llama3_2_vision/11B_lora.yaml b/recipes/configs/llama3_2_vision/11B_lora.yaml index 166eb7198e..3c051ff6ad 100644 --- a/recipes/configs/llama3_2_vision/11B_lora.yaml +++ b/recipes/configs/llama3_2_vision/11B_lora.yaml @@ -80,7 +80,7 @@ enable_activation_offloading: False dtype: bf16 # Logging -output_dir: /tmp/full-llama3.2-vision-finetune +output_dir: /tmp/lora-llama3.2-vision-finetune metric_logger: _component_: torchtune.training.metric_logging.DiskLogger log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs diff --git a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml index 7a8215ca41..3b62523738 100644 --- a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml +++ b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml @@ -79,7 +79,7 @@ enable_activation_offloading: False dtype: bf16 # Logging -output_dir: /tmp/full-llama3.2-vision-finetune +output_dir: /tmp/lora-llama3.2-vision-finetune metric_logger: _component_: torchtune.training.metric_logging.DiskLogger log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs diff --git a/recipes/configs/llama3_2_vision/11B_qlora.yaml b/recipes/configs/llama3_2_vision/11B_qlora.yaml new file mode 100644 index 0000000000..166eb7198e --- /dev/null +++ b/recipes/configs/llama3_2_vision/11B_qlora.yaml @@ -0,0 +1,88 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a Llama3.2 11B Vision Instruct model +# +# This config assumes that you've run the following command before launching: +# tune download meta-llama/Llama-3.2-11B-Vision-Instruct --output-dir /tmp/Llama-3.2-11B-Vision-Instruct +# +# To launch on 2 devices, run the following command from root: +# tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2_vision/11B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training: +# tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2_vision/11B_lora checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# For single device LoRA finetuning please use 11B_lora_single_device.yaml +# or 11B_qlora_single_device.yaml + +# Model arguments +model: + _component_: torchtune.models.llama3_2_vision.lora_llama3_2_vision_11b + decoder_trainable: "frozen" + encoder_trainable: "lora" + fusion_trainable: "lora" + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + lora_dropout: 0.0 + image_size: 560 # Make sure this matches the image_size in tokenizer + +# Transform +tokenizer: + _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform + path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model + image_size: 560 + +# Checkpointer +checkpointer: + _component_: torchtune.training.FullModelMetaCheckpointer + checkpoint_dir: /tmp/Llama-3.2-11B-Vision-Instruct/original/ + checkpoint_files: [consolidated.pth] + recipe_checkpoint: null + output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/ + model_type: LLAMA3_VISION +resume_from_checkpoint: False + +# Dataset +dataset: + _component_: torchtune.datasets.multimodal.the_cauldron_dataset + subset: ocrvqa +seed: null +shuffle: True +collate_fn: torchtune.data.padded_collate_tiled_images_and_mask + +# Fine-tuning arguments +epochs: 1 +max_steps_per_epoch: null +batch_size: 2 +gradient_accumulation_steps: 4 +optimizer: + _component_: torch.optim.AdamW + fused: True + weight_decay: 0.01 + lr: 2e-5 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +clip_grad_norm: 1.0 +compile: False + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +enable_activation_offloading: False +dtype: bf16 + +# Logging +output_dir: /tmp/full-llama3.2-vision-finetune +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml new file mode 100644 index 0000000000..97f4612f1d --- /dev/null +++ b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml @@ -0,0 +1,112 @@ +# Config for single device LoRA finetuning in lora_finetune_single_device.py +# using a Llama3.2 11B Vision Instruct model +# +# This config assumes that you've run the following command before launching: +# tune download meta-llama/Llama-3.2-11B-Vision-Instruct --output-dir /tmp/Llama-3.2-11B-Vision-Instruct +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config llama3_2_vision/11B_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training: +# tune run lora_finetune_single_device --config llama3_2_vision/11B_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Model arguments +model: + _component_: torchtune.models.llama3_2_vision.qlora_llama3_2_vision_11b + decoder_trainable: "frozen" + encoder_trainable: "lora" + fusion_trainable: "lora" + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + lora_dropout: 0.0 + image_size: 560 # Make sure this matches the image_size in tokenizer + +# Transform +tokenizer: + _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform + path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model + image_size: 560 + +# Checkpointer +checkpointer: + _component_: torchtune.training.FullModelMetaCheckpointer + checkpoint_dir: /tmp/Llama-3.2-11B-Vision-Instruct/original/ + checkpoint_files: [consolidated.pth] + recipe_checkpoint: null + output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/ + model_type: LLAMA3_VISION +resume_from_checkpoint: False + +# Dataset +dataset: + _component_: torchtune.datasets.multimodal.the_cauldron_dataset + subset: ocrvqa +seed: null +shuffle: True +collate_fn: torchtune.data.padded_collate_tiled_images_and_mask + +# Fine-tuning arguments +epochs: 1 +max_steps_per_epoch: null +batch_size: 2 +gradient_accumulation_steps: 16 +optimizer: + _component_: torch.optim.AdamW + fused: True + weight_decay: 0.01 + lr: 2e-5 +optimizer_in_bwd: False +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +clip_grad_norm: 1.0 +compile: False + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +enable_activation_offloading: False +dtype: bf16 + +# Logging +output_dir: /tmp/qlora-llama3.2-vision-finetune +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: True + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 1 + warmup_steps: 2 + active_steps: 1 + num_cycles: 1 diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index e1c7f8c3c5..7edeca5719 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -216,6 +216,10 @@ class Recipe: name="llama3_2_vision/11B_lora_single_device", file_path="llama3_2_vision/11B_lora_single_device.yaml", ), + Config( + name="llama3_2_vision/11B_qlora_single_device", + file_path="llama3_2_vision/11B_qlora_single_device.yaml", + ), ], supports_distributed=False, ), @@ -289,6 +293,10 @@ class Recipe: name="llama3_2_vision/11B_lora", file_path="llama3_2_vision/11B_lora.yaml", ), + Config( + name="llama3_2_vision/11B_qlora", + file_path="llama3_2_vision/11B_qlora.yaml", + ), ], supports_distributed=True, ), diff --git a/torchtune/models/clip/_component_builders.py b/torchtune/models/clip/_component_builders.py index 0940d49359..88580ca78b 100644 --- a/torchtune/models/clip/_component_builders.py +++ b/torchtune/models/clip/_component_builders.py @@ -9,22 +9,27 @@ import torch from torch import nn - -from torchtune.modules.vision_transformer import VisionTransformer, CLSProjection -from torchtune.models.clip._position_embeddings import TokenPositionalEmbedding, TiledTokenPositionalEmbedding, TilePositionalEmbedding +from torchtune.models.clip._position_embeddings import ( + TiledTokenPositionalEmbedding, + TilePositionalEmbedding, + TokenPositionalEmbedding, +) from torchtune.modules import ( - TransformerSelfAttentionLayer, + FeedForward, + Fp32LayerNorm, + FrozenNF4Linear, MultiHeadAttention, TanhGate, - FeedForward, - Fp32LayerNorm + TransformerSelfAttentionLayer, ) from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear +from torchtune.modules.vision_transformer import CLSProjection, VisionTransformer + def clip_vision_encoder( tile_size: int, @@ -43,7 +48,7 @@ def clip_vision_encoder( ) -> VisionTransformer: """ Builds the vision encoder associated with the clip model. This includes: - + - TransformerEncoderLayer - positional embeddings - CLS projection (optional) @@ -82,21 +87,25 @@ def clip_vision_encoder( """ assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" - cls_projection = CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim) if output_cls_projection else None + cls_projection = ( + CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim) + if output_cls_projection + else None + ) # transformer layer self_attn = MultiHeadAttention( - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_heads, - head_dim=embed_dim // num_heads, - q_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), - k_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), - v_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), - output_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), - pos_embeddings=None, - attn_dropout=0.0, - is_causal=False, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_heads, + head_dim=embed_dim // num_heads, + q_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), + k_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), + v_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), + output_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), + pos_embeddings=None, + attn_dropout=0.0, + is_causal=False, ) mlp = clip_mlp( in_dim=embed_dim, @@ -107,8 +116,8 @@ def clip_vision_encoder( transformer_layer = TransformerSelfAttentionLayer( attn=self_attn, mlp=mlp, - sa_norm= Fp32LayerNorm(embed_dim, eps=1e-5), - mlp_norm= Fp32LayerNorm(embed_dim, eps=1e-5), + sa_norm=Fp32LayerNorm(embed_dim, eps=1e-5), + mlp_norm=Fp32LayerNorm(embed_dim, eps=1e-5), sa_scale=None, mlp_scale=None, ) @@ -118,17 +127,21 @@ def clip_vision_encoder( pre_tile_pos_embed = None post_tile_pos_embed = None token_pos_embedding = TokenPositionalEmbedding( - embed_dim=embed_dim, - patch_size=patch_size, - tile_size=tile_size) + embed_dim=embed_dim, patch_size=patch_size, tile_size=tile_size + ) else: - pre_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim) - post_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim) + pre_tile_pos_embed = TilePositionalEmbedding( + max_num_tiles=max_num_tiles, embed_dim=embed_dim + ) + post_tile_pos_embed = TilePositionalEmbedding( + max_num_tiles=max_num_tiles, embed_dim=embed_dim + ) token_pos_embedding = TiledTokenPositionalEmbedding( - max_num_tiles=max_num_tiles, - embed_dim=embed_dim, - patch_size=patch_size, - tile_size=tile_size) + max_num_tiles=max_num_tiles, + embed_dim=embed_dim, + patch_size=patch_size, + tile_size=tile_size, + ) return VisionTransformer( num_layers=num_layers, @@ -145,13 +158,29 @@ def clip_vision_encoder( ) -def clip_mlp(in_dim: int, out_dim: int, hidden_dim: int, activation: nn.Module, quantize_base: bool = False) -> FeedForward: +def clip_mlp( + in_dim: int, + out_dim: int, + hidden_dim: int, + activation: nn.Module, + quantize_base: bool = False, +) -> FeedForward: """ Build the MLP layer associated with the clip model. """ - gate_proj = nn.Linear(in_dim, hidden_dim) if not quantize_base else FrozenNF4Linear(in_dim, hidden_dim) - down_proj = nn.Linear(hidden_dim, out_dim) if not quantize_base else FrozenNF4Linear(hidden_dim, out_dim) - return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation) + gate_proj = ( + nn.Linear(in_dim, hidden_dim) + if not quantize_base + else FrozenNF4Linear(in_dim, hidden_dim) + ) + down_proj = ( + nn.Linear(hidden_dim, out_dim) + if not quantize_base + else FrozenNF4Linear(hidden_dim, out_dim) + ) + return FeedForward( + gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation + ) # ------------------ LoRA CLIP ------------------ @@ -222,7 +251,7 @@ def lora_clip_vision_encoder( quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base weights within linear layers LoRA is applied to. The final output linear projection is not supported for quantization currently. - + Returns: VisionTransformer: Instantiation of VisionTransformer model. @@ -230,34 +259,38 @@ def lora_clip_vision_encoder( assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" # TODO: add support for quantizing and LoRA for the final output projection - cls_projection = CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim) if output_cls_projection else None + cls_projection = ( + CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim) + if output_cls_projection + else None + ) # transformer layer self_attn = lora_clip_attention( - lora_modules=lora_modules, - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_heads, - head_dim=embed_dim // num_heads, - attn_dropout=0.0, + lora_modules=lora_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_heads, + head_dim=embed_dim // num_heads, + attn_dropout=0.0, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + if apply_lora_to_mlp: + mlp = lora_clip_mlp( + in_dim=embed_dim, + hidden_dim=4 * embed_dim, + out_dim=embed_dim, + activation=activation(), lora_rank=lora_rank, lora_alpha=lora_alpha, + quantize_base=quantize_base, lora_dropout=lora_dropout, use_dora=use_dora, - quantize_base=quantize_base, - ) - if apply_lora_to_mlp: - mlp = lora_clip_mlp( - in_dim=embed_dim, - hidden_dim=4 * embed_dim, - out_dim=embed_dim, - activation=activation(), - lora_rank=lora_rank, - lora_alpha=lora_alpha, - quantize_base=quantize_base, - lora_dropout=lora_dropout, - use_dora=use_dora, - ) + ) else: mlp = clip_mlp( in_dim=embed_dim, @@ -269,8 +302,8 @@ def lora_clip_vision_encoder( transformer_layer = TransformerSelfAttentionLayer( attn=self_attn, mlp=mlp, - sa_norm= Fp32LayerNorm(embed_dim, eps=1e-5), - mlp_norm= Fp32LayerNorm(embed_dim, eps=1e-5), + sa_norm=Fp32LayerNorm(embed_dim, eps=1e-5), + mlp_norm=Fp32LayerNorm(embed_dim, eps=1e-5), sa_scale=None, mlp_scale=None, ) @@ -280,17 +313,21 @@ def lora_clip_vision_encoder( pre_tile_pos_embed = None post_tile_pos_embed = None token_pos_embedding = TokenPositionalEmbedding( - embed_dim=embed_dim, - patch_size=patch_size, - tile_size=tile_size) + embed_dim=embed_dim, patch_size=patch_size, tile_size=tile_size + ) else: - pre_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim) - post_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim) + pre_tile_pos_embed = TilePositionalEmbedding( + max_num_tiles=max_num_tiles, embed_dim=embed_dim + ) + post_tile_pos_embed = TilePositionalEmbedding( + max_num_tiles=max_num_tiles, embed_dim=embed_dim + ) token_pos_embedding = TiledTokenPositionalEmbedding( - max_num_tiles=max_num_tiles, - embed_dim=embed_dim, - patch_size=patch_size, - tile_size=tile_size) + max_num_tiles=max_num_tiles, + embed_dim=embed_dim, + patch_size=patch_size, + tile_size=tile_size, + ) model = VisionTransformer( num_layers=num_layers, @@ -482,4 +519,6 @@ def lora_clip_mlp( dropout=lora_dropout, quantize_base=quantize_base, ) - return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation) + return FeedForward( + gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation + ) diff --git a/torchtune/modules/low_precision/nf4_linear.py b/torchtune/modules/low_precision/nf4_linear.py index 9b0eaf53a3..ae99568820 100644 --- a/torchtune/modules/low_precision/nf4_linear.py +++ b/torchtune/modules/low_precision/nf4_linear.py @@ -18,7 +18,6 @@ class FrozenNF4Linear(nn.Linear): NF4Tensor as its weight. This class also freezes its ``weight`` parameter and is meant to be used as the base Linear layer for modeling use cases such as QLoRA where base model parameters are frozen. - NOTE: biases are currently not supported. Args: in_dim (int): input dimension @@ -27,18 +26,16 @@ class FrozenNF4Linear(nn.Linear): device given by `torch.get_default_device()`. **kwargs: any additional arguments to pass to the underlying Linear layer. - Raises: - RuntimeError: if ``bias`` is set to ``True`` """ def __init__( self, in_dim: int, out_dim: int, device: Optional[torch.device] = None, **kwargs ): - if "bias" in kwargs and kwargs.pop("bias"): - raise RuntimeError("FrozenNF4Linear does not currently support biases!") - super().__init__(in_dim, out_dim, device=device, bias=False, **kwargs) + super().__init__(in_dim, out_dim, device=device, **kwargs) self.weight.requires_grad_(False) + if self.bias is not None: + self.bias.requires_grad_(False) self.nf4_weight = to_nf4(self.weight) # re-register self.weight as the nf4 weight, so that the nf4 weight # shows up as expected in .parameters, state_dict, etc. @@ -57,4 +54,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: Returns: Tensor: output tensor """ - return linear_nf4(input=input, weight=self.weight) + out = linear_nf4(input=input, weight=self.weight) + if self.bias is not None: + out = out + self.bias + return out diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py index 57d72af672..3b90a89306 100644 --- a/torchtune/modules/peft/lora.py +++ b/torchtune/modules/peft/lora.py @@ -95,10 +95,6 @@ def _create_weight_and_bias(self): weight = linear.weight if not self._quantize_base else to_nf4(linear.weight) bias = None if self.use_bias: - if self._quantize_base: - raise NotImplementedError( - "Quantized LoRALinear does not support bias at the moment." - ) bias = linear.bias return weight, bias @@ -123,6 +119,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ if self._quantize_base: out = linear_nf4(input=x, weight=self.weight) + if self.use_bias: + out = out + self.bias else: out = F.linear(x, self.weight, self.bias) if self.disabled: From fc662fdef4cdb0b15f4ebd8efd6bf00bbb8ecbdf Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Mon, 30 Sep 2024 17:56:19 -0700 Subject: [PATCH 02/11] couple config fixes --- recipes/configs/llama3_2_vision/11B_qlora.yaml | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/recipes/configs/llama3_2_vision/11B_qlora.yaml b/recipes/configs/llama3_2_vision/11B_qlora.yaml index 166eb7198e..b9a9b25d54 100644 --- a/recipes/configs/llama3_2_vision/11B_qlora.yaml +++ b/recipes/configs/llama3_2_vision/11B_qlora.yaml @@ -5,19 +5,18 @@ # tune download meta-llama/Llama-3.2-11B-Vision-Instruct --output-dir /tmp/Llama-3.2-11B-Vision-Instruct # # To launch on 2 devices, run the following command from root: -# tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2_vision/11B_lora +# tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2_vision/11B_qlora # # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training: -# tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2_vision/11B_lora checkpointer.checkpoint_dir= +# tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2_vision/11B_qlora checkpointer.checkpoint_dir= # # This config works best when the model is being fine-tuned on 2+ GPUs. -# For single device LoRA finetuning please use 11B_lora_single_device.yaml -# or 11B_qlora_single_device.yaml +# For single device QLoRA finetuning please use 11B_qlora_single_device.yaml # Model arguments model: - _component_: torchtune.models.llama3_2_vision.lora_llama3_2_vision_11b + _component_: torchtune.models.llama3_2_vision.qlora_llama3_2_vision_11b decoder_trainable: "frozen" encoder_trainable: "lora" fusion_trainable: "lora" @@ -80,7 +79,7 @@ enable_activation_offloading: False dtype: bf16 # Logging -output_dir: /tmp/full-llama3.2-vision-finetune +output_dir: /tmp/qlora-llama3.2-vision-finetune metric_logger: _component_: torchtune.training.metric_logging.DiskLogger log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs From 8c7cd3d86154e0271088822be2ac11e8e2590ba9 Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Mon, 30 Sep 2024 17:56:26 -0700 Subject: [PATCH 03/11] couple config fixes --- recipes/configs/llama3_2_vision/11B_qlora.yaml | 2 +- recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/recipes/configs/llama3_2_vision/11B_qlora.yaml b/recipes/configs/llama3_2_vision/11B_qlora.yaml index b9a9b25d54..d0c3069a4b 100644 --- a/recipes/configs/llama3_2_vision/11B_qlora.yaml +++ b/recipes/configs/llama3_2_vision/11B_qlora.yaml @@ -1,4 +1,4 @@ -# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# Config for multi-device QLoRA finetuning in lora_finetune_distributed.py # using a Llama3.2 11B Vision Instruct model # # This config assumes that you've run the following command before launching: diff --git a/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml index 97f4612f1d..c253f247e6 100644 --- a/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml +++ b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml @@ -1,4 +1,4 @@ -# Config for single device LoRA finetuning in lora_finetune_single_device.py +# Config for single device QLoRA finetuning in lora_finetune_single_device.py # using a Llama3.2 11B Vision Instruct model # # This config assumes that you've run the following command before launching: From 63057b3ba75a52e56dae42940e5eb7a42b64f69a Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Wed, 23 Oct 2024 13:50:29 -0700 Subject: [PATCH 04/11] unit tests --- .../modules/low_precision/test_nf4_linear.py | 9 +-- tests/torchtune/modules/peft/test_lora.py | 70 +++++++++---------- torchtune/modules/low_precision/nf4_linear.py | 11 ++- 3 files changed, 45 insertions(+), 45 deletions(-) diff --git a/tests/torchtune/modules/low_precision/test_nf4_linear.py b/tests/torchtune/modules/low_precision/test_nf4_linear.py index e29a87ba57..fcdb81c260 100644 --- a/tests/torchtune/modules/low_precision/test_nf4_linear.py +++ b/tests/torchtune/modules/low_precision/test_nf4_linear.py @@ -40,10 +40,6 @@ class TestNF4Linear: Class for testing our NF4Linear implementation. """ - def test_bias_unsupported(self): - with pytest.raises(RuntimeError, match="does not currently support biases"): - _ = FrozenNF4Linear(1, 1, bias=True) - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) def test_parameters(self, dtype): nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) @@ -59,9 +55,10 @@ def test_state_dict(self, dtype): assert isinstance(state_dict["weight"], NF4Tensor) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) - def test_output_dtype(self, dtype): + @pytest.mark.parametrize("bias", [True, False]) + def test_output_dtype(self, dtype, bias): # Test to ensure W4 A16 produces A16 / W4A32 produces A32 - nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) + nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype, bias=bias) inp = torch.randn(2, 512, dtype=dtype, requires_grad=True) out = nf4_linear(inp) assert out.dtype == dtype diff --git a/tests/torchtune/modules/peft/test_lora.py b/tests/torchtune/modules/peft/test_lora.py index caca54b86b..118079ef77 100644 --- a/tests/torchtune/modules/peft/test_lora.py +++ b/tests/torchtune/modules/peft/test_lora.py @@ -61,19 +61,22 @@ def lora_linear(self, in_dim, out_dim) -> LoRALinear: return lora_linear @pytest.fixture - def qlora_linear(self, in_dim, out_dim) -> LoRALinear: - with training.set_default_dtype(torch.bfloat16): - qlora_linear = LoRALinear( - in_dim=512, - out_dim=512, - rank=RANK, - alpha=ALPHA, - use_bias=False, - quantize_base=True, - ) - fixed_init_model(qlora_linear, dtype=torch.bfloat16) + def qlora_linear(self): + def create_qlora_linear(use_bias, dtype): + with training.set_default_dtype(dtype): + qlora_linear = LoRALinear( + in_dim=512, + out_dim=512, + rank=RANK, + alpha=ALPHA, + use_bias=use_bias, + quantize_base=True, + ) + # fixed_init_model(qlora_linear, dtype=torch.bfloat16) return qlora_linear + return create_qlora_linear + @torch.no_grad() def set_dummy_weights_for_merge(self, lora_module): lora_module.lora_a.weight = nn.Parameter( @@ -97,50 +100,45 @@ def test_forward(self, inputs, lora_linear, out_dim) -> None: assert actual.shape == (BSZ, SEQ_LEN, out_dim) torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6) - def test_lora_weight_nf4_when_quantized(self, qlora_linear): + @pytest.mark.parametrize("use_bias", [True, False]) + def test_lora_weight_nf4_when_quantized(self, use_bias, qlora_linear): + qlora_linear = qlora_linear(use_bias=use_bias, dtype=torch.bfloat16) assert isinstance(qlora_linear.weight, NF4Tensor) - - def test_quantize_with_bias_raises(self): - with pytest.raises(NotImplementedError, match="does not support bias"): - LoRALinear( - in_dim=512, - out_dim=512, - rank=RANK, - alpha=ALPHA, - use_bias=True, - quantize_base=True, - ) - - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) - def test_qlora_parity(self, dtype): + if use_bias: + assert not isinstance(qlora_linear.bias, NF4Tensor) + assert qlora_linear.bias.dtype == torch.bfloat16 + + # Note: with bfloat16 F.linear(x, weight, bias) != F.linear(x, weight) + bias. + # This means we would get different results (irrespective of QLoRA). + # So we leave that test case out + @pytest.mark.parametrize( + "use_bias, dtype", + [(False, torch.bfloat16), (True, torch.float32), (False, torch.float32)], + ) + def test_qlora_parity(self, use_bias, dtype, qlora_linear): + qlora_linear = qlora_linear(use_bias=use_bias, dtype=dtype) with training.set_default_dtype(dtype): - qlora_linear = LoRALinear( - in_dim=512, - out_dim=512, - rank=RANK, - alpha=ALPHA, - use_bias=False, - quantize_base=True, - ) lora_linear = LoRALinear( in_dim=512, out_dim=512, rank=RANK, alpha=ALPHA, - use_bias=False, + use_bias=use_bias, quantize_base=False, ) # set weight of lora_linear to unquantized weight of qlora_linear and check # parity. lora_linear.weight.data = qlora_linear.weight.to(dtype) - + if use_bias: + lora_linear.bias.data = qlora_linear.bias.detach().clone() # Ensure forward passes are the same. This is because LoRALinear should use a special # quantized linear operator that runs compute in higher prec (but only saves the 4 bit quantized tensor) # for autograd. inputs = torch.randn(BSZ, SEQ_LEN, 512, dtype=dtype) lora_linear_out = lora_linear(inputs) qlora_linear_out = qlora_linear(inputs) + torch.testing.assert_close(lora_linear_out, qlora_linear_out) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) diff --git a/torchtune/modules/low_precision/nf4_linear.py b/torchtune/modules/low_precision/nf4_linear.py index ae99568820..593acf861e 100644 --- a/torchtune/modules/low_precision/nf4_linear.py +++ b/torchtune/modules/low_precision/nf4_linear.py @@ -24,15 +24,20 @@ class FrozenNF4Linear(nn.Linear): out_dim (int): output dimension device (Optional[torch.device]): device to use for the underlying weight. If ``None``, uses the default device given by `torch.get_default_device()`. + bias (bool): whether to include bias in the original linear layer. Default: False **kwargs: any additional arguments to pass to the underlying Linear layer. """ def __init__( - self, in_dim: int, out_dim: int, device: Optional[torch.device] = None, **kwargs + self, + in_dim: int, + out_dim: int, + device: Optional[torch.device] = None, + bias: bool = False, + **kwargs, ): - - super().__init__(in_dim, out_dim, device=device, **kwargs) + super().__init__(in_dim, out_dim, device=device, bias=bias, **kwargs) self.weight.requires_grad_(False) if self.bias is not None: self.bias.requires_grad_(False) From d5960b1f729efbef10c397f39cc2e77a56dfe2ca Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Wed, 23 Oct 2024 14:28:33 -0700 Subject: [PATCH 05/11] bias=true in clip_mlp --- recipes/lora_finetune_single_device.py | 3 +++ torchtune/models/clip/_component_builders.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 5d39b72086..144c5892a6 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -429,6 +429,9 @@ def _setup_model( ) else: lora_missing, lora_unexpected = None, None + import pdb + + pdb.set_trace() validate_missing_and_unexpected_for_lora( lora_attn_modules=self._lora_attn_modules, apply_lora_to_mlp=self._apply_lora_to_mlp, diff --git a/torchtune/models/clip/_component_builders.py b/torchtune/models/clip/_component_builders.py index a2810a76f4..772d1e32df 100644 --- a/torchtune/models/clip/_component_builders.py +++ b/torchtune/models/clip/_component_builders.py @@ -170,12 +170,12 @@ def clip_mlp( gate_proj = ( nn.Linear(in_dim, hidden_dim) if not quantize_base - else FrozenNF4Linear(in_dim, hidden_dim) + else FrozenNF4Linear(in_dim, hidden_dim, bias=True) ) down_proj = ( nn.Linear(hidden_dim, out_dim) if not quantize_base - else FrozenNF4Linear(hidden_dim, out_dim) + else FrozenNF4Linear(hidden_dim, out_dim, bias=True) ) return FeedForward( gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation From 9112c6a9148d09975f8c4f6b433ea420f5a76593 Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Wed, 23 Oct 2024 14:33:50 -0700 Subject: [PATCH 06/11] remove debug code --- recipes/lora_finetune_single_device.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 144c5892a6..5d39b72086 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -429,9 +429,6 @@ def _setup_model( ) else: lora_missing, lora_unexpected = None, None - import pdb - - pdb.set_trace() validate_missing_and_unexpected_for_lora( lora_attn_modules=self._lora_attn_modules, apply_lora_to_mlp=self._apply_lora_to_mlp, From aaaaffba084fb30adcd4562f4cae0f273b2f3b7e Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Wed, 23 Oct 2024 15:08:54 -0700 Subject: [PATCH 07/11] some config merges --- recipes/configs/llama3_2_vision/11B_qlora.yaml | 5 +++-- recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/recipes/configs/llama3_2_vision/11B_qlora.yaml b/recipes/configs/llama3_2_vision/11B_qlora.yaml index d0c3069a4b..1217fb367a 100644 --- a/recipes/configs/llama3_2_vision/11B_qlora.yaml +++ b/recipes/configs/llama3_2_vision/11B_qlora.yaml @@ -33,6 +33,7 @@ tokenizer: _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model image_size: 560 + max_seq_len: 8192 # Checkpointer checkpointer: @@ -63,12 +64,12 @@ optimizer: weight_decay: 0.01 lr: 2e-5 lr_scheduler: - _component_: torchtune.modules.get_cosine_schedule_with_warmup + _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup num_warmup_steps: 100 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss clip_grad_norm: 1.0 -compile: False +compile: False # set it to True for better memory and performance # Training env device: cuda diff --git a/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml index c253f247e6..b12d51237c 100644 --- a/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml +++ b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml @@ -32,6 +32,7 @@ tokenizer: _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model image_size: 560 + max_seq_len: 8192 # Checkpointer checkpointer: @@ -63,12 +64,12 @@ optimizer: lr: 2e-5 optimizer_in_bwd: False lr_scheduler: - _component_: torchtune.modules.get_cosine_schedule_with_warmup + _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup num_warmup_steps: 100 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss clip_grad_norm: 1.0 -compile: False +compile: False # set it to True for better memory and performance # Training env device: cuda From 7e2c9535e3e97e1d479814b80ffee30870fc7ce4 Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Thu, 24 Oct 2024 13:43:33 -0700 Subject: [PATCH 08/11] address comments --- tests/torchtune/modules/peft/test_lora.py | 3 ++- torchtune/modules/low_precision/nf4_linear.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/torchtune/modules/peft/test_lora.py b/tests/torchtune/modules/peft/test_lora.py index 118079ef77..f9df68bd28 100644 --- a/tests/torchtune/modules/peft/test_lora.py +++ b/tests/torchtune/modules/peft/test_lora.py @@ -72,7 +72,7 @@ def create_qlora_linear(use_bias, dtype): use_bias=use_bias, quantize_base=True, ) - # fixed_init_model(qlora_linear, dtype=torch.bfloat16) + fixed_init_model(qlora_linear, dtype=torch.bfloat16) return qlora_linear return create_qlora_linear @@ -126,6 +126,7 @@ def test_qlora_parity(self, use_bias, dtype, qlora_linear): use_bias=use_bias, quantize_base=False, ) + fixed_init_model(lora_linear, dtype=torch.bfloat16) # set weight of lora_linear to unquantized weight of qlora_linear and check # parity. diff --git a/torchtune/modules/low_precision/nf4_linear.py b/torchtune/modules/low_precision/nf4_linear.py index 593acf861e..0c387ffa0f 100644 --- a/torchtune/modules/low_precision/nf4_linear.py +++ b/torchtune/modules/low_precision/nf4_linear.py @@ -24,7 +24,7 @@ class FrozenNF4Linear(nn.Linear): out_dim (int): output dimension device (Optional[torch.device]): device to use for the underlying weight. If ``None``, uses the default device given by `torch.get_default_device()`. - bias (bool): whether to include bias in the original linear layer. Default: False + bias (bool): whether to include bias in the linear layer. Default: False **kwargs: any additional arguments to pass to the underlying Linear layer. """ @@ -41,11 +41,11 @@ def __init__( self.weight.requires_grad_(False) if self.bias is not None: self.bias.requires_grad_(False) - self.nf4_weight = to_nf4(self.weight) + nf4_weight = to_nf4(self.weight) # re-register self.weight as the nf4 weight, so that the nf4 weight # shows up as expected in .parameters, state_dict, etc. torch.utils.swap_tensors( - self.weight, torch.nn.Parameter(self.nf4_weight, requires_grad=False) + self.weight, torch.nn.Parameter(nf4_weight, requires_grad=False) ) def forward(self, input: torch.Tensor) -> torch.Tensor: From 46f6dd0d0e6454fde1b542179a94a2b464d18931 Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Thu, 24 Oct 2024 16:27:36 -0700 Subject: [PATCH 09/11] update dora (and unit test), fix checkpoint save --- tests/torchtune/modules/peft/test_dora.py | 107 +++++++++--------- tests/torchtune/modules/peft/test_lora.py | 2 +- .../llama3_2_vision/_component_builders.py | 97 +++++++++++----- torchtune/modules/peft/_utils.py | 1 + torchtune/modules/peft/dora.py | 25 ++-- 5 files changed, 139 insertions(+), 93 deletions(-) diff --git a/tests/torchtune/modules/peft/test_dora.py b/tests/torchtune/modules/peft/test_dora.py index 02954fb685..1b48608852 100644 --- a/tests/torchtune/modules/peft/test_dora.py +++ b/tests/torchtune/modules/peft/test_dora.py @@ -49,80 +49,77 @@ def inputs(self, in_dim) -> torch.Tensor: return inputs @pytest.fixture - def dora_linear(self, in_dim, out_dim) -> DoRALinear: - dora_linear = DoRALinear( - in_dim=in_dim, - out_dim=out_dim, - rank=RANK, - alpha=ALPHA, - use_bias=False, - ) + def dora_linear(self, in_dim, out_dim): + def create_dora_linear(use_bias, dtype, in_dim=in_dim, out_dim=out_dim): + with training.set_default_dtype(dtype): + dora_linear = DoRALinear( + in_dim=in_dim, + out_dim=out_dim, + rank=RANK, + alpha=ALPHA, + use_bias=use_bias, + ) - fixed_init_model(dora_linear) - return dora_linear + fixed_init_model(dora_linear) + return dora_linear + + return create_dora_linear @pytest.fixture - def qdora_linear(self, in_dim, out_dim) -> DoRALinear: - with training.set_default_dtype(torch.bfloat16): - qdora_linear = DoRALinear( - in_dim=512, - out_dim=512, - rank=RANK, - alpha=ALPHA, - use_bias=False, - quantize_base=True, - ) - fixed_init_model(qdora_linear, dtype=torch.bfloat16) + def qdora_linear(self): + def create_qdora_linear( + use_bias=False, dtype=torch.bfloat16, in_dim=512, out_dim=512 + ): + with training.set_default_dtype(dtype): + qdora_linear = DoRALinear( + in_dim=in_dim, + out_dim=out_dim, + rank=RANK, + alpha=ALPHA, + use_bias=use_bias, + quantize_base=True, + ) + fixed_init_model(qdora_linear) return qdora_linear + return create_qdora_linear + def test_forward(self, inputs, dora_linear, out_dim) -> None: + dora_linear = dora_linear(use_bias=False, dtype=torch.float32) expected = torch.tensor(EXPECTED_VAL) actual = dora_linear(inputs) assert actual.shape == (BSZ, SEQ_LEN, out_dim) torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6) - def test_dora_weight_nf4_when_quantized(self, qdora_linear): + @pytest.mark.parametrize("use_bias", [True, False]) + def test_dora_weight_nf4_when_quantized(self, use_bias, qdora_linear): + qdora_linear = qdora_linear(use_bias=use_bias, dtype=torch.bfloat16) assert isinstance(qdora_linear.weight, NF4Tensor) - - def test_bias_raises(self): - with pytest.raises( - NotImplementedError, match="DoRALinear does not support using bias" - ): - DoRALinear( - in_dim=512, - out_dim=512, - rank=RANK, - alpha=ALPHA, - use_bias=True, - quantize_base=False, - ) - - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) - def test_qdora_parity(self, dtype): + if use_bias: + assert not isinstance(qdora_linear.bias, NF4Tensor) + assert qdora_linear.bias.dtype == torch.bfloat16 + + # Note: with bfloat16 F.linear(x, weight, bias) != F.linear(x, weight) + bias. + # This means we would get different results (irrespective of QDoRA). + # So we leave that test case out + @pytest.mark.parametrize( + "use_bias, dtype", + [(False, torch.bfloat16), (True, torch.float32), (False, torch.float32)], + ) + def test_qdora_parity(self, use_bias, dtype, dora_linear, qdora_linear): with training.set_default_dtype(dtype): - torch.manual_seed(0) - qdora_linear = DoRALinear( - in_dim=512, - out_dim=512, - rank=RANK, - alpha=ALPHA, - use_bias=False, - quantize_base=True, + qdora_linear = qdora_linear( + use_bias=use_bias, dtype=dtype, in_dim=512, out_dim=512 ) - torch.manual_seed(0) - dora_linear = DoRALinear( - in_dim=512, - out_dim=512, - rank=RANK, - alpha=ALPHA, - use_bias=False, - quantize_base=False, + dora_linear = dora_linear( + use_bias=use_bias, dtype=dtype, in_dim=512, out_dim=512 ) # set weight of dora_linear to unquantized weight of qdora_linear and check # parity. dora_linear.weight.data = qdora_linear.weight.to(dtype) - + if use_bias: + dora_linear.bias.data = qdora_linear.bias.detach().clone() qdora_linear.initialize_dora_magnitude() dora_linear.initialize_dora_magnitude() diff --git a/tests/torchtune/modules/peft/test_lora.py b/tests/torchtune/modules/peft/test_lora.py index f9df68bd28..72061e0682 100644 --- a/tests/torchtune/modules/peft/test_lora.py +++ b/tests/torchtune/modules/peft/test_lora.py @@ -72,7 +72,7 @@ def create_qlora_linear(use_bias, dtype): use_bias=use_bias, quantize_base=True, ) - fixed_init_model(qlora_linear, dtype=torch.bfloat16) + fixed_init_model(qlora_linear) return qlora_linear return create_qlora_linear diff --git a/torchtune/models/llama3_2_vision/_component_builders.py b/torchtune/models/llama3_2_vision/_component_builders.py index 111393501d..8881c87531 100644 --- a/torchtune/models/llama3_2_vision/_component_builders.py +++ b/torchtune/models/llama3_2_vision/_component_builders.py @@ -4,31 +4,44 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from functools import partial from enum import Enum -from typing import Optional, List +from functools import partial +from typing import List, Optional from torch import nn +from torchtune.models.clip._component_builders import ( + clip_mlp, + clip_vision_encoder, + lora_clip_attention, + lora_clip_mlp, + lora_clip_vision_encoder, +) from torchtune.models.llama3._model_utils import scale_hidden_dim_for_mlp -from torchtune.models.llama3_1._component_builders import llama3_mlp, lora_llama3_mlp, lora_llama3_attention +from torchtune.models.llama3_1._component_builders import ( + llama3_mlp, + lora_llama3_attention, + lora_llama3_mlp, +) from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE -from torchtune.models.clip._component_builders import clip_vision_encoder, clip_mlp, lora_clip_attention, lora_clip_mlp, lora_clip_vision_encoder -from torchtune.models.llama3_2_vision._encoder import Llama3VisionProjectionHead, Llama3VisionEncoder - -from torchtune.modules.model_fusion import FusionEmbedding, FusionLayer +from torchtune.models.llama3_2_vision._encoder import ( + Llama3VisionEncoder, + Llama3VisionProjectionHead, +) from torchtune.modules import ( + Fp32LayerNorm, + MultiHeadAttention, RMSNorm, TanhGate, TransformerCrossAttentionLayer, - MultiHeadAttention, TransformerDecoder, TransformerSelfAttentionLayer, - Fp32LayerNorm ) from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook +from torchtune.modules.model_fusion import FusionEmbedding, FusionLayer + from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear @@ -59,7 +72,7 @@ def llama3_2_vision_encoder( tile_size: int, max_num_tiles: int = 4, in_channels: int = 3, - ) -> Llama3VisionEncoder: +) -> Llama3VisionEncoder: """ Build the Llama 3.2 vision encoder by combining the CLIP image model with an additional projection head fusion module. This includes: @@ -76,7 +89,7 @@ def llama3_2_vision_encoder( clip_embed_dim (int): The dimensionality of each patch embedding in CLIP. clip_num_layers (int): The number of transformer layers. clip_hidden_states (Optional[List[int]]): The indices of CLIP hidden layers to return - to return to the encoder projection head. It will return the intermediate results + to return to the encoder projection head. It will return the intermediate results of the vision transformer layers which will be concatenated with the CLIP output and input into the projection head. For example, ``clip_hidden_states=[0,3]`` will return the embeddings before they go through the first and fourth layers. @@ -113,7 +126,7 @@ def llama3_2_vision_encoder( num_heads=num_heads, decoder_embed_dim=decoder_embed_dim, clip_embed_dim=clip_embed_dim, - num_hidden_inputs=len(clip_hidden_states or []) + num_hidden_inputs=len(clip_hidden_states or []), ) return Llama3VisionEncoder(clip=clip, projection_head=projection_head) @@ -239,6 +252,7 @@ def llama3_2_vision_decoder( output=output_proj, ) + def llama3_2_vision_projection_head( *, num_layers: int, @@ -306,9 +320,10 @@ def llama3_2_vision_projection_head( return Llama3VisionProjectionHead( layers=layers, output=nn.Linear(proj_in, decoder_embed_dim), - num_hidden_inputs=num_hidden_inputs + num_hidden_inputs=num_hidden_inputs, ) + # ------------------ LoRA Llama 3.2 Vision ------------------ @@ -344,7 +359,7 @@ def lora_llama3_2_vision_encoder( lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, - ) -> Llama3VisionEncoder: +) -> Llama3VisionEncoder: """ Build the Llama 3.2 vision encoder by combining the CLIP image model with an additional projection head fusion module. This includes: @@ -370,7 +385,7 @@ def lora_llama3_2_vision_encoder( clip_embed_dim (int): The dimensionality of each patch embedding in CLIP. clip_num_layers (int): The number of transformer layers. clip_hidden_states (Optional[List[int]]): The indices of CLIP hidden layers to return - to return to the encoder projection head. It will return the intermediate results + to return to the encoder projection head. It will return the intermediate results of the vision transformer layers which will be concatenated with the CLIP output and input into the projection head. For example, ``clip_hidden_states=[0,3]`` will return the embeddings before they go through the first and fourth layers. @@ -388,7 +403,7 @@ def lora_llama3_2_vision_encoder( quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base weights within linear layers LoRA is applied to. The final output linear projection is not supported for quantization currently. - + Returns: Llama3VisionEncoder: Instantiation of Llama 3.2 vision encoder. @@ -423,7 +438,7 @@ def lora_llama3_2_vision_encoder( else: clip = clip_vision_encoder(**clip_options) - # Projection + # Projection projection_options = { "num_layers": num_layers_projection, "num_heads": num_heads, @@ -432,11 +447,22 @@ def lora_llama3_2_vision_encoder( "num_hidden_inputs": len(clip_hidden_states or []), } if fusion_lora: - projection_head = lora_llama3_2_vision_projection_head(**projection_options, **lora_options) + projection_head = lora_llama3_2_vision_projection_head( + **projection_options, **lora_options + ) else: projection_head = lora_llama3_2_vision_projection_head(**projection_options) - return Llama3VisionEncoder(clip=clip, projection_head=projection_head) + encoder = Llama3VisionEncoder(clip=clip, projection_head=projection_head) + + if quantize_base: + # For QLoRA, we reparametrize 4-bit tensors to bf16, and offload to CPU on the fly + # so as to not increase peak memory + encoder._register_state_dict_hook( + partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True) + ) + + return encoder def lora_llama3_2_vision_decoder( @@ -458,7 +484,7 @@ def lora_llama3_2_vision_decoder( encoder_max_seq_len: int, rope_base: int = 500000.0, intermediate_dim: Optional[int] = None, - # LoRA parameters + # LoRA parameters lora_rank: int = 8, lora_alpha: float = 16, lora_dropout: float = 0.0, @@ -546,7 +572,9 @@ def lora_llama3_2_vision_decoder( use_dora=use_dora, ) else: - mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base) + mlp = llama3_mlp( + dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base + ) decoder_layer = TransformerSelfAttentionLayer( attn=self_attn, mlp=mlp, @@ -586,7 +614,9 @@ def lora_llama3_2_vision_decoder( use_dora=use_dora, ) else: - mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base) + mlp = llama3_mlp( + dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base + ) xattn_layer = TransformerCrossAttentionLayer( attn=attn, mlp=mlp, @@ -601,11 +631,17 @@ def lora_llama3_2_vision_decoder( layers.append(decoder_layer) tok_embeddings = FusionEmbedding(vocab_size, num_special_tokens, embed_dim) - - # TODO: quantize_base is not applied to final output_proj currently. + + # TODO: quantize_base is not applied to final output_proj currently. adapter_cls = DoRALinear if use_dora else LoRALinear output_proj = ( - adapter_cls(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout) + adapter_cls( + embed_dim, + vocab_size, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + ) if apply_lora_to_output else nn.Linear(embed_dim, vocab_size, bias=False) ) @@ -713,7 +749,7 @@ def lora_llama3_2_vision_projection_head( hidden_dim=hidden_dim, out_dim=clip_embed_dim, activation=nn.GELU(), - quantize_base=quantize_base + quantize_base=quantize_base, ) layer = TransformerSelfAttentionLayer( @@ -733,7 +769,14 @@ def lora_llama3_2_vision_projection_head( proj_in = clip_embed_dim * (num_hidden_inputs + 1) adapter_cls = DoRALinear if use_dora else LoRALinear output_proj = ( - adapter_cls(proj_in, decoder_embed_dim, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout, use_bias=True) + adapter_cls( + proj_in, + decoder_embed_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + use_bias=True, + ) if apply_lora_to_output else nn.Linear(proj_in, decoder_embed_dim) ) diff --git a/torchtune/modules/peft/_utils.py b/torchtune/modules/peft/_utils.py index 4768d77619..3791154ba1 100644 --- a/torchtune/modules/peft/_utils.py +++ b/torchtune/modules/peft/_utils.py @@ -259,6 +259,7 @@ def get_merged_lora_ckpt( # Otherwise it is just vanilla LoRA else: + print(f"module is {module}") state_dict[f"{module}.weight"] += ( (alpha / rank) * lora_b_weight @ lora_a_weight ) diff --git a/torchtune/modules/peft/dora.py b/torchtune/modules/peft/dora.py index e0a8fe9788..153b3c78e1 100644 --- a/torchtune/modules/peft/dora.py +++ b/torchtune/modules/peft/dora.py @@ -39,8 +39,6 @@ class DoRALinear(nn.Module, AdapterModule): quantize_base (bool): Whether to quantize base linear weight or not. Default: False - Raises: - NotImplementedError: If use_bias is enabled. """ def __init__( @@ -54,14 +52,16 @@ def __init__( quantize_base: bool = False, ): super().__init__() - if use_bias: - raise NotImplementedError("DoRALinear does not support using bias") self.in_dim = in_dim self.out_dim = out_dim self.scaling = alpha / rank + self.use_bias = use_bias self._quantize_base = quantize_base - weight = self._create_weight() + weight, bias = self._create_weight_and_bias() self.register_parameter("weight", nn.Parameter(weight)) + self.register_parameter( + "bias", nn.Parameter(bias) if bias is not None else None + ) # 'self.disabled' is a flag showing whether to turn off DoRA adapters, # this can be used in DPO for treating the dora adapters as the policy model @@ -90,15 +90,18 @@ def initialize_dora_magnitude(self): weight_norm = self._get_weight_norm(base_weight, lora_weight) self.magnitude = nn.Parameter(weight_norm, requires_grad=True) - def _create_weight(self): + def _create_weight_and_bias(self): """ Creates a linear weight and bias tensor, using NF4 dtype if we're quantizing (indicated via quantize_base=True). """ - in_dim, out_dim = self.in_dim, self.out_dim - linear = nn.Linear(in_features=in_dim, out_features=out_dim, bias=False) + in_dim, out_dim, use_bias = self.in_dim, self.out_dim, self.use_bias + linear = nn.Linear(in_features=in_dim, out_features=out_dim, bias=use_bias) weight = linear.weight if not self._quantize_base else to_nf4(linear.weight) - return weight + bias = None + if self.use_bias: + bias = linear.bias + return weight, bias def _get_weight_norm(self, weight, lora_weight): weight = weight + self.scaling * lora_weight @@ -123,8 +126,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ if self._quantize_base: base_out = linear_nf4(input=x, weight=self.weight) + if self.use_bias: + base_out = base_out + self.bias else: - base_out = F.linear(x, self.weight) + base_out = F.linear(x, self.weight, self.bias) if self.disabled: return base_out From b1973688f584faff7212af6c5e994ade52c9b466 Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Thu, 24 Oct 2024 16:41:05 -0700 Subject: [PATCH 10/11] cleanup debug code, v nice state dict hook workaround --- torchtune/modules/model_fusion/_fusion.py | 9 +++++---- torchtune/modules/peft/_utils.py | 1 - 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index 1a5452daae..907c7a2ed0 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -232,10 +232,11 @@ def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs): """Apply extra "embedding" prefix to the state_dict key to account for the FusionEmbedding wrapping. """ - key = prefix + "weight" - new_key = prefix + "embedding.weight" - state_dict[new_key] = state_dict[key] - del state_dict[key] + if state_dict: + key = prefix + "weight" + new_key = prefix + "embedding.weight" + state_dict[new_key] = state_dict[key] + del state_dict[key] def fusion_params(self) -> List[str]: """ diff --git a/torchtune/modules/peft/_utils.py b/torchtune/modules/peft/_utils.py index 3791154ba1..4768d77619 100644 --- a/torchtune/modules/peft/_utils.py +++ b/torchtune/modules/peft/_utils.py @@ -259,7 +259,6 @@ def get_merged_lora_ckpt( # Otherwise it is just vanilla LoRA else: - print(f"module is {module}") state_dict[f"{module}.weight"] += ( (alpha / rank) * lora_b_weight @ lora_a_weight ) From 09491e61c7bf022a443ef8dc4942bf53f9b81906 Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Thu, 24 Oct 2024 16:56:29 -0700 Subject: [PATCH 11/11] fix failing unit test --- tests/torchtune/modules/peft/test_lora.py | 48 +++++++++++------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/tests/torchtune/modules/peft/test_lora.py b/tests/torchtune/modules/peft/test_lora.py index 72061e0682..fc76adea30 100644 --- a/tests/torchtune/modules/peft/test_lora.py +++ b/tests/torchtune/modules/peft/test_lora.py @@ -50,23 +50,27 @@ def inputs(self, in_dim) -> torch.Tensor: @pytest.fixture def lora_linear(self, in_dim, out_dim) -> LoRALinear: - lora_linear = LoRALinear( - in_dim=in_dim, - out_dim=out_dim, - rank=RANK, - alpha=ALPHA, - use_bias=True, - ) - fixed_init_model(lora_linear) - return lora_linear + def create_lora_linear(use_bias, dtype, in_dim=in_dim, out_dim=out_dim): + with training.set_default_dtype(dtype): + lora_linear = LoRALinear( + in_dim=in_dim, + out_dim=out_dim, + rank=RANK, + alpha=ALPHA, + use_bias=use_bias, + ) + fixed_init_model(lora_linear) + return lora_linear + + return create_lora_linear @pytest.fixture def qlora_linear(self): - def create_qlora_linear(use_bias, dtype): + def create_qlora_linear(use_bias, dtype, in_dim=512, out_dim=512): with training.set_default_dtype(dtype): qlora_linear = LoRALinear( - in_dim=512, - out_dim=512, + in_dim=in_dim, + out_dim=out_dim, rank=RANK, alpha=ALPHA, use_bias=use_bias, @@ -95,6 +99,7 @@ def set_dummy_weights_for_merge(self, lora_module): lora_module.lora_b.weight[32, 1] = 12 def test_forward(self, inputs, lora_linear, out_dim) -> None: + lora_linear = lora_linear(use_bias=True, dtype=torch.float32) expected = torch.tensor(EXPECTED_VAL) actual = lora_linear(inputs) assert actual.shape == (BSZ, SEQ_LEN, out_dim) @@ -115,18 +120,13 @@ def test_lora_weight_nf4_when_quantized(self, use_bias, qlora_linear): "use_bias, dtype", [(False, torch.bfloat16), (True, torch.float32), (False, torch.float32)], ) - def test_qlora_parity(self, use_bias, dtype, qlora_linear): - qlora_linear = qlora_linear(use_bias=use_bias, dtype=dtype) - with training.set_default_dtype(dtype): - lora_linear = LoRALinear( - in_dim=512, - out_dim=512, - rank=RANK, - alpha=ALPHA, - use_bias=use_bias, - quantize_base=False, - ) - fixed_init_model(lora_linear, dtype=torch.bfloat16) + def test_qlora_parity(self, use_bias, dtype, qlora_linear, lora_linear): + qlora_linear = qlora_linear( + use_bias=use_bias, dtype=dtype, in_dim=512, out_dim=512 + ) + lora_linear = lora_linear( + use_bias=use_bias, dtype=dtype, in_dim=512, out_dim=512 + ) # set weight of lora_linear to unquantized weight of qlora_linear and check # parity.