run ddp mode in fsdp
ddlBoJack committed Nov 28, 2023
1 parent 81345c8 commit e37aea4
Showing 10 changed files with 106 additions and 37 deletions.
31 changes: 0 additions & 31 deletions scripts/finetune.sh

This file was deleted.

67 changes: 67 additions & 0 deletions scripts/finetune_echat.sh
@@ -0,0 +1,67 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0,1
export CUDA_LAUNCH_BLOCKING=1
export OMP_NUM_THREADS=1

# debug setting for multiple gpus
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=ALL
# export TORCH_DISTRIBUTED_DEBUG=INFO

cd /root/SLAM-LLM

speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/base.pt
llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/maziyang.mzy/models/llama-2-hf-finetune

# -m debugpy --listen 5678 --wait-for-client
if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
python src/llama_recipes/pipeline/finetune.py \
--model_name echat \
--use_peft --peft_method lora \
--llm_name llama-2-7b-hf \
--llm_path $llm_path \
--encoder_name whisper \
--encoder_path $speech_encoder_path \
--encoder_projector linear \
--dataset custom_dataset \
--custom_dataset.file src/llama_recipes/datasets/echat_dataset.py:get_audio_dataset \
--custom_dataset.data_path /nfs/zhifu.gzf/data/IEMOCAP_full_release/datalist.jsonl \
--batching_strategy custom \
--custom_dataset.max_words 1024 \
--num_epochs 100 \
--batch_size_training 2 \
--output_dir $output_dir \
--run_test_during_validation true \
--run_test_during_validation_file /nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_M013.wav \
# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \

else
torchrun \
--nnodes 1 \
--nproc_per_node 2 \
src/llama_recipes/pipeline/finetune.py \
--model_name echat \
--enable_fsdp \
--use_peft --peft_method lora \
--llm_name llama-2-7b-hf \
--llm_path $llm_path \
--encoder_name whisper \
--encoder_path $speech_encoder_path \
--encoder_projector linear \
--dataset custom_dataset \
--custom_dataset.file src/llama_recipes/datasets/echat_dataset.py:get_audio_dataset \
--custom_dataset.data_path /nfs/zhifu.gzf/data/IEMOCAP_full_release/datalist.jsonl \
--batching_strategy custom \
--custom_dataset.max_words 1024 \
--num_epochs 100 \
--batch_size_training 8 \
--val_batch_size 8 \
--output_dir $output_dir \
--run_test_during_validation \
--run_test_during_validation_file /nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_M013.wav \
# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \
fi
File renamed without changes.
3 changes: 2 additions & 1 deletion src/llama_recipes/configs/fsdp.py
@@ -10,7 +10,8 @@
class fsdp_config:
mixed_precision: bool=True
use_fp16: bool=False
sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
# sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
sharding_strategy: ShardingStrategy = ShardingStrategy.NO_SHARD #MZY: set NO_SHARD to use DDP mode in FSDP
checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size on resume.
fsdp_activation_checkpointing: bool=True
fsdp_cpu_offload: bool=False
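
For readers unfamiliar with the trick in this hunk: ShardingStrategy.NO_SHARD keeps a full copy of the parameters on every rank and only all-reduces gradients, which gives DDP semantics while staying on the FSDP code path (auto-wrapping, mixed precision, activation checkpointing). A minimal, self-contained sketch of the idea, using an illustrative toy model that is not from this repository:

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

def wrap_ddp_style(model: nn.Module) -> FSDP:
    # NO_SHARD: full parameters on every rank, gradients all-reduced,
    # i.e. DDP behaviour through the FSDP wrapper.
    return FSDP(model, sharding_strategy=ShardingStrategy.NO_SHARD)

if __name__ == "__main__":
    dist.init_process_group("nccl")                       # launched via torchrun
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # torchrun sets LOCAL_RANK
    model = wrap_ddp_style(nn.Linear(16, 16).cuda())

The fsdp.py change above is exactly this switch; the rest of the training loop stays on the FSDP path.
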
2 changes: 2 additions & 0 deletions src/llama_recipes/configs/training.py
@@ -36,3 +36,5 @@ class train_config:
dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
save_optimizer: bool=False # will be used if using FSDP
use_fast_kernels: bool = False # Enable SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and xFormers memory-efficient kernels
run_test_during_validation: bool = False
run_test_during_validation_file: str = "test.wav"
@@ -161,6 +161,7 @@ def save_model_checkpoint(
print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")

def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0):
print(f"--> saving model ...")
save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch))
os.makedirs(save_dir, exist_ok=True)
model.llm.save_pretrained(save_dir)
3 changes: 2 additions & 1 deletion src/llama_recipes/models/slam_model.py
@@ -96,7 +96,7 @@ def setup_llm(train_config, model_config, **kwargs):
except ImportError:
print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")

print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
print_model_size(model, train_config, int(os.environ["RANK"]) if train_config.enable_fsdp else 0)

# Prepare the model for int8 training if quantization is enabled
if train_config.quantization:
@@ -161,6 +161,7 @@ def forward(self,
speech_encoder_outs = self.speech_encoder_projector(speech_encoder_outs)

input_ids[input_ids == -1] = 0
# print(input_ids[0])
if hasattr(self.llm.model, "embed_tokens"):
inputs_embeds = self.llm.model.embed_tokens(input_ids)
elif hasattr(self.llm.model.model, "embed_tokens"):
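
The print_model_size change in setup_llm reads the global rank from the RANK environment variable rather than a rank passed in by the caller. As a hedged aside, the helper below is hypothetical (not in the repository) and only illustrates the two ways a rank can be obtained when processes are started by torchrun:

import os
import torch.distributed as dist

def get_global_rank() -> int:
    # torchrun exports RANK / LOCAL_RANK / WORLD_SIZE for every worker, so the
    # env var is usable even before init_process_group(); dist.get_rank()
    # only works after the process group is initialized.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return int(os.environ.get("RANK", "0"))
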
1 change: 1 addition & 0 deletions src/llama_recipes/pipeline/finetune.py
@@ -6,6 +6,7 @@

# nn
import torch
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# opt
import torch.optim as optim
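
The new LlamaDecoderLayer import in finetune.py is the class usually handed to FSDP's transformer auto-wrap policy so that each decoder block becomes its own FSDP unit. The wiring itself is not visible in this diff, so the snippet below shows only the common pattern and is not a copy of the repository's code:

import functools
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Hypothetical wiring: wrap every LlamaDecoderLayer as a separate FSDP unit.
llama_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)
# later: model = FSDP(model, auto_wrap_policy=llama_auto_wrap_policy, ...)
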
7 changes: 7 additions & 0 deletions src/llama_recipes/utils/config_utils.py
@@ -96,6 +96,13 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
kwargs["collate_fn"] = default_data_collator
else:
# raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
if train_config.enable_fsdp:
kwargs["sampler"] = DistributedSampler(
dataset,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
shuffle=mode=="train",
)
kwargs["batch_size"] = batch_size
kwargs["drop_last"] = True
kwargs["collate_fn"] = dataset.collator
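
With NO_SHARD (DDP-style) execution every rank must see a different shard of the data, which is what the DistributedSampler attached above provides. A minimal sketch of how such a sampler pairs with a DataLoader; the dataset here is a placeholder, not the repository's custom dataset:

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Assumes dist.init_process_group() has already been called by the launcher/trainer.
dataset = TensorDataset(torch.arange(100))  # placeholder dataset
sampler = DistributedSampler(
    dataset,
    rank=dist.get_rank(),
    num_replicas=dist.get_world_size(),
    shuffle=True,  # shuffle only for the training split
)
loader = DataLoader(dataset, batch_size=8, sampler=sampler, drop_last=True)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reshuffles the per-rank shard each epoch
    for batch in loader:
        pass

Note that set_epoch is what makes the shuffle change across epochs; the hunk above only constructs the sampler.
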
28 changes: 24 additions & 4 deletions src/llama_recipes/utils/train_utils.py
@@ -144,10 +144,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
print(f"we are about to save the PEFT modules")
else:
print(f"we are about to save the PEFT modules")
# model.save_pretrained(train_config.output_dir)
save_model_checkpoint_peft(
model, optimizer, rank, train_config, epoch=epoch
)
if train_config.enable_fsdp:
if rank==0:
save_model_checkpoint_peft(
model, optimizer, rank, train_config, epoch=epoch
)
dist.barrier()
else:
# model.save_pretrained(train_config.output_dir)
save_model_checkpoint_peft(
model, optimizer, rank, train_config, epoch=epoch
)
if train_config.enable_fsdp:
if rank==0:
print(f"PEFT modules are saved in {train_config.output_dir} directory")
@@ -189,6 +196,19 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
val_loss.append(best_val_loss)
val_prep.append(eval_ppl)
if train_config.run_test_during_validation:
if train_config.enable_fsdp:
if rank==0:
print("=====================================")
print(f"Test the file {train_config.run_test_during_validation_file} during validation:")
print(model.generate(train_config.run_test_during_validation_file))
print("=====================================")
dist.barrier()
else:
print("=====================================")
print(f"Test the file {train_config.run_test_during_validation_file} during validation:")
print(model.generate(train_config.run_test_during_validation_file))
print("=====================================")
if train_config.enable_fsdp:
if rank==0:
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
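
Both new branches above (saving the PEFT checkpoint and running a test file during validation) use the same rank-0-plus-barrier pattern that NO_SHARD makes safe: every rank holds a full copy of the weights, so only rank 0 needs to write or print, and the barrier keeps the remaining ranks in step. A minimal sketch of the pattern, with a placeholder callback rather than the repository's functions:

import torch.distributed as dist

def run_on_rank_zero(fn, enable_fsdp: bool, rank: int) -> None:
    # fn stands in for "save the PEFT adapter" or "generate on a test sample".
    if enable_fsdp:
        if rank == 0:
            fn()
        dist.barrier()  # keep all ranks in lockstep before training resumes
    else:
        fn()            # single-process run: no coordination needed
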
