
Commit

add freeze_LLM_only option for mllama finetuning (meta-llama#791)
wukaixingxp authored Nov 20, 2024
2 parents 39bfabf + d31ee18 commit e5662e5
Showing 5 changed files with 82 additions and 5 deletions.
1 change: 1 addition & 0 deletions recipes/quickstart/finetuning/README.md
@@ -54,6 +54,7 @@ It lets us specify the training settings for everything from `model_name` to `da
output_dir: str = "PATH/to/save/PEFT/model"
freeze_layers: bool = False
num_freeze_layers: int = 1
freeze_LLM_only: bool = False # Freeze the self-attention layers in the language_model; the vision model, multi_modal_projector, and cross-attention layers will be fine-tuned
quantization: str = None
one_gpu: bool = False
save_model: bool = True
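As a quick illustration of the new flag (a sketch, not part of this commit), the option can also be set programmatically on the training config dataclass. The import path follows `src/llama_recipes/configs/training.py` shown later in this commit; the other field names are assumed from the CLI flags used in the docs below:

```python
# Minimal sketch: toggling freeze_LLM_only on the training config dataclass.
# Assumes the module path from src/llama_recipes/configs/training.py in this diff;
# batching_strategy/model_name are assumed to mirror the CLI flags used in the docs.
from llama_recipes.configs.training import train_config

cfg = train_config()
cfg.freeze_LLM_only = True          # freeze the language model's self-attention layers
cfg.batching_strategy = "padding"   # the vision model does not support packing
cfg.model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

print(cfg.freeze_LLM_only, cfg.batching_strategy)
```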
6 changes: 6 additions & 0 deletions recipes/quickstart/finetuning/finetune_vision_model.md
@@ -18,6 +18,12 @@ For **LoRA finetuning with FSDP**, we can run the following code:
```bash
torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py" --run_validation True --batching_strategy padding --use_peft --peft_method lora
```

For **finetuning with a frozen LLM (`freeze_LLM_only`) using FSDP**, we can run the following code:

```bash
torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py" --run_validation True --batching_strategy padding --freeze_LLM_only True
```
**Note**: `--batching_strategy padding` is required, as the vision model does not work with the `packing` strategy.

For more details about the finetuning configurations, please read the [finetuning readme](./README.md).
1 change: 1 addition & 0 deletions src/llama_recipes/configs/training.py
@@ -35,6 +35,7 @@ class train_config:
output_dir: str = "PATH/to/save/PEFT/model"
freeze_layers: bool = False
num_freeze_layers: int = 1
freeze_LLM_only: bool = False # Freeze the self-attention layers in the language_model; the vision model, multi_modal_projector, and cross-attention layers will be fine-tuned
quantization: str = None
one_gpu: bool = False
save_model: bool = True
21 changes: 18 additions & 3 deletions src/llama_recipes/finetuning.py
@@ -38,8 +38,10 @@
from llama_recipes.utils.train_utils import (
clear_gpu_cache,
freeze_transformer_layers,
freeze_LLM_only,
get_policies,
print_model_size,
print_frozen_model_status,
setup,
setup_environ_flags,
train,
@@ -194,7 +196,7 @@ def main(**kwargs):
model.resize_token_embeddings(len(tokenizer))

print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

# Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
if (
train_config.enable_fsdp
@@ -235,7 +237,14 @@

if not train_config.use_peft and train_config.freeze_layers:
freeze_transformer_layers(model, train_config.num_freeze_layers)

# Print the trainable-parameter count and per-module frozen status after freezing layers
print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)

if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
freeze_LLM_only(model)
# Print the trainable-parameter count and per-module frozen status after freezing the language model
print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)

mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
# Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, and MllamaVisionEncoderLayer in vision models
if is_vision:
@@ -255,6 +264,11 @@
device_id = torch.xpu.current_device()
elif torch.cuda.is_available():
device_id = torch.cuda.current_device()

# FSDP's default flat-parameter mode requires uniform requires_grad within each wrapped module,
# so use_orig_params=True is needed when freeze_LLM_only leaves frozen and trainable params mixed
if train_config.freeze_LLM_only:
use_orig_params = True
else:
use_orig_params = False
model = FSDP(
model,
auto_wrap_policy=(
@@ -282,6 +296,7 @@
if train_config.low_cpu_fsdp and rank != 0
else None
),
use_orig_params=use_orig_params,
)
if fsdp_config.fsdp_activation_checkpointing:
model.enable_input_require_grads()
@@ -297,7 +312,7 @@
dataset_processer = processor
else:
dataset_processer = tokenizer

# Load and preprocess the dataset for training and validation

dataset_train = get_preprocessed_dataset(
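The `use_orig_params` switch added above matters because `freeze_LLM_only` leaves frozen and trainable parameters inside the same FSDP-wrapped decoder blocks. The standalone sketch below (illustrative only; the toy module and helper are not part of the commit) shows the mixed state the flag accounts for:

```python
# Sketch: detecting a wrapped block that mixes frozen and trainable parameters.
# FSDP's default flat-parameter mode expects uniform requires_grad inside each
# wrapped unit, which is why the recipe switches to use_orig_params=True.
import torch.nn as nn

def mixes_frozen_and_trainable(module: nn.Module) -> bool:
    """Return True if the module holds both frozen and trainable parameters."""
    flags = {p.requires_grad for p in module.parameters()}
    return flags == {True, False}

# Hypothetical stand-in for a language_model block: a frozen self-attention
# projection next to a trainable cross-attention projection.
block = nn.ModuleDict({"self_attn": nn.Linear(8, 8), "cross_attn": nn.Linear(8, 8)})
for p in block["self_attn"].parameters():
    p.requires_grad = False

use_orig_params = mixes_frozen_and_trainable(block)
print(f"use_orig_params -> {use_orig_params}")  # True: the block is mixed
```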
58 changes: 56 additions & 2 deletions src/llama_recipes/utils/train_utils.py
@@ -409,7 +409,17 @@ def freeze_transformer_layers(model, num_layer):
if i < num_layer:
for param in layer.parameters():
param.requires_grad = False


def freeze_LLM_only(model):
"""
Freeze the self-attention layers in the language_model; the vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned.
"""
for name, param in model.language_model.named_parameters():
param.requires_grad = False
for i, layer in enumerate(model.language_model.model.layers):
if i in model.language_model.model.cross_attention_layers:
for param in layer.parameters():
param.requires_grad = True

def check_frozen_layers_peft_model(model):
for i, layer in enumerate(model.base_model.model.model.layers):
@@ -476,8 +486,52 @@ def print_model_size(model, config, rank: int = 0) -> None:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")

def print_frozen_model_status(model, config, rank: int = 0) -> None:
"""
Print the per-module frozen status of the model and the number of trainable parameters after freezing.

Args:
model: The PyTorch model.
config: The training config (provides model_name).
rank (int, optional): Current process's rank. Defaults to 0.
"""
if rank == 0:
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("After freezing the model:")
print(f"--> {config.model_name} has {trainable_params / 1e6} Million trainable params\n")

module_states = {}
# Iterate over all parameters
for name, param in model.named_parameters():
# Extract the top-level module name (e.g., "vision_model", "language_model")
top_module = name.split(".")[0]

# Initialize a record for the top-level module
if top_module not in module_states:
module_states[top_module] = {"frozen": [], "unfrozen": []}

# Group parameters into frozen or unfrozen
if param.requires_grad:
module_states[top_module]["unfrozen"].append(name)
else:
module_states[top_module]["frozen"].append(name)

print("--> Model state after freezing:")
# Analyze and print the results
for module, states in module_states.items():
frozen_params = states["frozen"]
unfrozen_params = states["unfrozen"]

if frozen_params and unfrozen_params:
# Mixed state: both frozen and unfrozen parameters
print(f" {module}: Mixed")
elif frozen_params:
# All parameters are frozen
print(f" {module}: Frozen")
else:
# All parameters are unfrozen
print(f" {module}: Unfrozen")
print("")

def get_policies(cfg, rank):
"""Get the policies for mixed precision and fsdp wrapping"""
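Together, `freeze_LLM_only` and `print_frozen_model_status` follow a freeze-everything-then-re-enable pattern plus a per-module audit of the result. A self-contained sketch of the same pattern on a toy model (illustrative only; it assumes nothing beyond plain PyTorch):

```python
# Sketch: freeze a submodule, re-enable selected layers, then report each
# top-level module as Frozen / Unfrozen / Mixed, mirroring the helpers above.
import torch.nn as nn

model = nn.ModuleDict({
    "language_model": nn.ModuleList([nn.Linear(8, 8) for _ in range(4)]),
    "vision_model": nn.Linear(8, 8),
})
cross_attention_layers = [1, 3]  # toy stand-in for the real cross-attention indices

# Freeze everything in the language model ...
for p in model["language_model"].parameters():
    p.requires_grad = False
# ... then re-enable the layers that should stay trainable.
for i, layer in enumerate(model["language_model"]):
    if i in cross_attention_layers:
        for p in layer.parameters():
            p.requires_grad = True

# Per-module status report, analogous to print_frozen_model_status.
status = {}
for name, p in model.named_parameters():
    top = name.split(".")[0]
    status.setdefault(top, set()).add(p.requires_grad)
for top, flags in status.items():
    label = "Mixed" if len(flags) == 2 else ("Unfrozen" if True in flags else "Frozen")
    print(f"{top}: {label}")
```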
