Merge pull request #5 from ddlBoJack/dev-mzy
some FSDP support
ddlBoJack authored Dec 5, 2023
2 parents 3d6216e + 552b32d commit 023a797
Showing 3 changed files with 25 additions and 9 deletions.
4 changes: 2 additions & 2 deletions scripts/finetune_echat.sh
@@ -69,9 +69,9 @@ torchrun \
src/llama_recipes/pipeline/finetune.py \
--model_name echat \
--freeze_encoder \
--freeze_llm \
--use_fp16 \
--enable_fsdp \
--use_peft --peft_method lora \
--llm_name llama-2-7b-hf \
--llm_path $llm_path \
--encoder_name whisper \
@@ -92,5 +92,5 @@ src/llama_recipes/pipeline/finetune.py \
--run_test_during_validation_file /nfs/zhifu.gzf/data/IEMOCAP_full_release/Session5/sentences/wav/Ses05M_impro04/Ses05M_impro04_M040.wav \
# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1"
# --freeze_llm \
# --use_peft --peft_method lora \
fi
28 changes: 22 additions & 6 deletions src/llama_recipes/models/slam_model.py
@@ -4,6 +4,7 @@
import soundfile as sf
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from typing import List, Optional, Tuple, Union
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import (
@@ -83,6 +84,7 @@ def setup_llm(train_config, model_config, **kwargs):
if not verify_latest_nightly:
raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
"please install latest nightly.")
rank = int(os.environ["RANK"])
if rank == 0:
model = LlamaForCausalLM.from_pretrained(
model_config.llm_path,
@@ -97,12 +99,26 @@ model = LlamaForCausalLM(llama_config)
model = LlamaForCausalLM(llama_config)

else:
model = LlamaForCausalLM.from_pretrained(
model_config.llm_path,
load_in_8bit=True if train_config.quantization else None,
device_map="auto" if train_config.quantization else None,
use_cache=use_cache,
)
if train_config.enable_fsdp:
rank = int(os.environ["RANK"])
if rank == 0:
model = LlamaForCausalLM.from_pretrained(
model_config.llm_path,
load_in_8bit=True if train_config.quantization else None,
device_map="auto" if train_config.quantization else None,
use_cache=use_cache,
)
else:
llama_config = LlamaConfig.from_pretrained(model_config.llm_path)
llama_config.use_cache = use_cache
model = LlamaForCausalLM(llama_config)
else:
model = LlamaForCausalLM.from_pretrained(
model_config.llm_path,
load_in_8bit=True if train_config.quantization else None,
device_map="auto" if train_config.quantization else None,
use_cache=use_cache,
)
if train_config.enable_fsdp and train_config.use_fast_kernels:
"""
For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
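The new enable_fsdp branch applies the same idea as the existing low_cpu_fsdp path: only rank 0 loads the pretrained checkpoint, while the other ranks build an uninitialized model from the config, so host memory is not multiplied by the number of ranks. A minimal sketch of that pattern, pulled out of setup_llm for clarity (the standalone helper name is illustrative; the pattern is typically paired with sync_module_states=True in the FSDP constructor so rank 0's weights are broadcast to the other ranks at wrap time):

    # Minimal sketch of the rank-0-only loading pattern used above (assumes
    # torchrun sets RANK; the helper name load_llm_for_fsdp is illustrative).
    import os

    from transformers import LlamaConfig, LlamaForCausalLM

    def load_llm_for_fsdp(llm_path, use_cache=False):
        rank = int(os.environ["RANK"])
        if rank == 0:
            # Only rank 0 materializes the pretrained weights.
            model = LlamaForCausalLM.from_pretrained(llm_path, use_cache=use_cache)
        else:
            # Other ranks build an uninitialized model from the config; FSDP can
            # then broadcast rank 0's weights (sync_module_states=True) at wrap time.
            config = LlamaConfig.from_pretrained(llm_path)
            config.use_cache = use_cache
            model = LlamaForCausalLM(config)
        return model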
2 changes: 1 addition & 1 deletion src/llama_recipes/pipeline/finetune.py
@@ -86,7 +86,7 @@ def main(**kwargs):

model = FSDP(
model,
auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
auto_wrap_policy= my_auto_wrapping_policy, #(FIX:MZY): Using my_auto_wrapping_policy whether peft or not. This will avoid model shard type check error of requires_grad mismatching.
cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
sharding_strategy=fsdp_config.sharding_strategy,
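For context, the auto-wrap policy matters here because parts of this model are frozen (encoder and/or LLM) even when PEFT is off; if a single FSDP flat parameter mixes frozen and trainable tensors, FSDP raises the requires_grad-uniformity error the comment refers to. A rough sketch of a per-layer policy in the spirit of my_auto_wrapping_policy (the builder below is illustrative, not the repository's actual helper, and assumes PyTorch 2.x FSDP policy signatures):

    # Rough sketch of a per-layer wrapping policy; names are illustrative.
    import functools

    from torch.distributed.fsdp.wrap import (
        lambda_auto_wrap_policy,
        transformer_auto_wrap_policy,
    )
    from transformers.models.llama.modeling_llama import LlamaDecoderLayer

    def build_auto_wrap_policy():
        # Each decoder layer becomes its own FSDP unit.
        transformer_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={LlamaDecoderLayer},
        )

        # Additionally wrap any leaf module whose weight is trainable (e.g. LoRA
        # adapters), so frozen and trainable tensors never share one flat parameter.
        def leaf_trainable(module):
            return (
                len(list(module.named_children())) == 0
                and getattr(module, "weight", None) is not None
                and module.weight.requires_grad
            )

        lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=leaf_trainable)

        def policy(module, recurse, nonwrapped_numel):
            # Wrap if either sub-policy says so.
            return lambda_policy(
                module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel
            ) or transformer_policy(
                module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel
            )

        return policy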
