Merge pull request #2 from ddlBoJack/debug-mzy-20231020
Debug mzy 20231020
LauraGPT authored Nov 29, 2023
2 parents 065b917 + e37aea4 commit f0db0a0
Showing 13 changed files with 128 additions and 56 deletions.
32 changes: 0 additions & 32 deletions scripts/finetune.sh

This file was deleted.

67 changes: 67 additions & 0 deletions scripts/finetune_echat.sh
@@ -0,0 +1,67 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0,1
export CUDA_LAUNCH_BLOCKING=1
export OMP_NUM_THREADS=1

# debug setting for multiple gpus
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=ALL
# export TORCH_DISTRIBUTED_DEBUG=INFO

cd /root/SLAM-LLM

speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/base.pt
llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/maziyang.mzy/models/llama-2-hf-finetune

# -m debugpy --listen 5678 --wait-for-client
if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
python src/llama_recipes/pipeline/finetune.py \
--model_name echat \
--use_peft --peft_method lora \
--llm_name llama-2-7b-hf \
--llm_path $llm_path \
--encoder_name whisper \
--encoder_path $speech_encoder_path \
--encoder_projector linear \
--dataset custom_dataset \
--custom_dataset.file src/llama_recipes/datasets/echat_dataset.py:get_audio_dataset \
--custom_dataset.data_path /nfs/zhifu.gzf/data/IEMOCAP_full_release/datalist.jsonl \
--batching_strategy custom \
--custom_dataset.max_words 1024 \
--num_epochs 100 \
--batch_size_training 2 \
--output_dir $output_dir \
--run_test_during_validation true \
--run_test_during_validation_file /nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_M013.wav \
# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \

else
torchrun \
--nnodes 1 \
--nproc_per_node 2 \
src/llama_recipes/pipeline/finetune.py \
--model_name echat \
--enable_fsdp \
--use_peft --peft_method lora \
--llm_name llama-2-7b-hf \
--llm_path $llm_path \
--encoder_name whisper \
--encoder_path $speech_encoder_path \
--encoder_projector linear \
--dataset custom_dataset \
--custom_dataset.file src/llama_recipes/datasets/echat_dataset.py:get_audio_dataset \
--custom_dataset.data_path /nfs/zhifu.gzf/data/IEMOCAP_full_release/datalist.jsonl \
--batching_strategy custom \
--custom_dataset.max_words 1024 \
--num_epochs 100 \
--batch_size_training 8 \
--val_batch_size 8 \
--output_dir $output_dir \
--run_test_during_validation \
--run_test_during_validation_file /nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_M013.wav \
# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \
fi
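
Note: the --custom_dataset.file flag above uses a "path/to/file.py:function" spec. A hedged sketch (an assumption for illustration, not the repo's actual loader; the helper name load_dataset_fn is hypothetical) of how such a spec can be resolved with importlib:

import importlib.util

def load_dataset_fn(spec: str):
    # split e.g. "src/llama_recipes/datasets/echat_dataset.py:get_audio_dataset"
    path, func_name = spec.split(":")
    module_spec = importlib.util.spec_from_file_location("custom_dataset_module", path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)   # execute the dataset file as a module
    return getattr(module, func_name)        # e.g. get_audio_dataset
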
6 changes: 3 additions & 3 deletions scripts/inference.sh → scripts/inference_echat.sh
@@ -7,7 +7,7 @@ cd /root/SLAM-LLM

speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/base.pt
llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/zhifu.gzf/models/llama-2-hf-finetune
output_dir=/nfs/maziyang.mzy/models/llama-2-hf-finetune

# -m debugpy --listen 5678 --wait-for-client
#python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \
@@ -28,5 +28,5 @@ python src/llama_recipes/pipeline/inference.py \
--num_epochs 1 \
--batch_size_training 2 \
--output_dir $output_dir \
--ckpt_path "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0/model.pt" \
--peft_ckpt "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0"
--ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/0/model.pt" \
--peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/0"
3 changes: 2 additions & 1 deletion src/llama_recipes/configs/fsdp.py
@@ -10,7 +10,8 @@
class fsdp_config:
mixed_precision: bool=True
use_fp16: bool=False
sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
# sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
sharding_strategy: ShardingStrategy = ShardingStrategy.NO_SHARD #MZY: set NO_SHARD to use DDP mode in FSDP
checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size.
fsdp_activation_checkpointing: bool=True
fsdp_cpu_offload: bool=False
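
For context, ShardingStrategy.NO_SHARD keeps a full copy of the parameters and gradients on every rank, so FSDP only synchronizes gradients, which is effectively DDP behavior. A minimal sketch (not from this commit), assuming torch.distributed is already initialized:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

model = nn.Linear(1024, 1024).cuda()
model = FSDP(model, sharding_strategy=ShardingStrategy.NO_SHARD)  # DDP-like: no parameter sharding
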
2 changes: 2 additions & 0 deletions src/llama_recipes/configs/training.py
@@ -36,3 +36,5 @@ class train_config:
dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
save_optimizer: bool=False # will be used if using FSDP
use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers; makes use of Flash Attention and xFormers memory-efficient kernels
run_test_during_validation: bool = False
run_test_during_validation_file: str = "test.wav"
6 changes: 3 additions & 3 deletions src/llama_recipes/datasets/echat_dataset.py
@@ -73,8 +73,8 @@ def __getitem__(self, index):
speech_length = (speech_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats
speech_pseudo = torch.full((speech_length,),-1)

example_ids = self.tokenizer.encode(answer) # FIX(GZF): [answer]
example_ids = prompt_ids + example_ids
example = prompt + answer #FIX(MZY): avoid putting a bos token before answer.
example_ids = self.tokenizer.encode(example) # [prompt,answer]
example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos]
example_ids = torch.tensor(
example_ids, dtype=torch.int64
@@ -152,7 +152,7 @@ def collator(self, samples):

speech_mask = torch.zeros_like(attention_mask)
for line, sample in enumerate(samples):
speech_mask[line, :sample['speech_length']] = 1 #FIX(GZF): sample['speech_length']+1
speech_mask[line, :sample['speech_length']] = 1

return {
'input_ids': input_ids,
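
The FIX(MZY) change above avoids a stray bos token: a Llama tokenizer's encode() prepends <s> on every call, so encoding prompt and answer separately puts a second <s> in the middle of the sequence. A minimal sketch of the difference (not part of this commit; the checkpoint name is only an assumption):

from transformers import LlamaTokenizer

tok = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed checkpoint
prompt, answer = "USER: How do you feel?\n ASSISTANT:", " <|happy|><|I feel great.|>"

old_ids = tok.encode(prompt) + tok.encode(answer)           # second encode() adds another <s>
new_ids = tok.encode(prompt + answer) + [tok.eos_token_id]  # single <s>, explicit </s>
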
1 change: 1 addition & 0 deletions src/llama_recipes/model_checkpointing/checkpoint_handler.py
@@ -161,6 +161,7 @@ def save_model_checkpoint(
print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")

def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0):
print(f"--> saving model ...")
save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch))
os.makedirs(save_dir, exist_ok=True)
model.llm.save_pretrained(save_dir)
18 changes: 7 additions & 11 deletions src/llama_recipes/models/slam_model.py
@@ -96,7 +96,7 @@ def setup_llm(train_config, model_config, **kwargs):
except ImportError:
print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")

print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
print_model_size(model, train_config, int(os.environ["RANK"]) if train_config.enable_fsdp else 0)

# Prepare the model for int8 training if quantization is enabled
if train_config.quantization:
@@ -135,12 +135,7 @@ def __init__(

# projector
self.speech_encoder_projector = nn.Linear(self.speech_encoder.ln_post.normalized_shape[0], self.llm.config.hidden_size)
ckpt_path = kwargs.get("ckpt_path", None)
# ckpt_path = kwargs.get("ckpt_path", "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0/model.pt")
if ckpt_path is not None:
print("loading ckpt from: ", ckpt_path)
ckpt_dict = torch.load(ckpt_path, map_location="cpu")
self.load_state_dict(ckpt_dict, strict=False)

# tokenizer
self.tokenizer = tokenizer

@@ -163,9 +158,10 @@ def forward(self,
speech_encoder_outs = None
if speech_mel is not None:
speech_encoder_outs = self.speech_encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1))
speech_encoder_outs = self.speech_encoder_projector.to(speech_encoder_outs.device)(speech_encoder_outs)
speech_encoder_outs = self.speech_encoder_projector(speech_encoder_outs)

input_ids[input_ids == -1] = 0
# print(input_ids[0])
if hasattr(self.llm.model, "embed_tokens"):
inputs_embeds = self.llm.model.embed_tokens(input_ids)
elif hasattr(self.llm.model.model, "embed_tokens"):
Expand All @@ -178,7 +174,7 @@ def forward(self,
speech_encoder_outs_pad = F.pad(speech_encoder_outs, (0, 0, 0, token_num-l, 0, 0), value=0.0)
inputs_embeds = speech_encoder_outs_pad * speech_mask[:, :, None] + inputs_embeds * (~speech_mask[:, :, None])

model_outputs = self.llm.to(speech_encoder_outs.device)(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)

return model_outputs

Expand All @@ -205,7 +201,7 @@ def generate(
speech_mel = whisper.log_mel_spectrogram(speech_raw).permute(1,0)[None, :, :].to(device)

speech_encoder_outs = self.speech_encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1))
speech_encoder_outs = self.speech_encoder_projector.to(speech_encoder_outs.device)(speech_encoder_outs)
speech_encoder_outs = self.speech_encoder_projector(speech_encoder_outs)

prompt="""
Please provide an emotional response based on the emotional speech you hear.
@@ -214,7 +210,7 @@ def generate(
<|REPLY|> is a reply based on the speech.
"""
prompt = "USER: {}\n ASSISTANT:".format(prompt)
prompt_ids = self.tokenizer.encode(prompt) # FIX(GZF)
prompt_ids = self.tokenizer.encode(prompt)
prompt_length = len(prompt_ids)
prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device)

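
The forward() edits above drop the per-call .to(device) moves (the whole model is now moved once in finetune.py) while keeping the mask-based splice of speech features into the text embeddings. A self-contained sketch of that splice (shapes are made up for illustration; not from this commit):

import torch
import torch.nn.functional as F

B, L_speech, L_text, D = 2, 10, 16, 32
speech_encoder_outs = torch.randn(B, L_speech, D)      # projected speech features
inputs_embeds = torch.randn(B, L_text, D)              # token embeddings (speech slots hold placeholder ids)
speech_mask = torch.zeros(B, L_text, dtype=torch.bool)
speech_mask[:, :L_speech] = True                       # positions reserved for speech

# pad the speech features up to the text length, then overwrite the reserved positions
speech_encoder_outs_pad = F.pad(speech_encoder_outs, (0, 0, 0, L_text - L_speech, 0, 0), value=0.0)
inputs_embeds = speech_encoder_outs_pad * speech_mask[:, :, None] + inputs_embeds * (~speech_mask[:, :, None])
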
3 changes: 3 additions & 0 deletions src/llama_recipes/pipeline/finetune.py
@@ -6,6 +6,7 @@

# nn
import torch
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# opt
import torch.optim as optim
@@ -66,6 +67,8 @@ def main(**kwargs):
setup_environ_flags(rank)

model, tokenizer = model_factory(train_config, model_config, **kwargs)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # FIX(MZY): put the whole model to device.
model.to(device)


# Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
3 changes: 2 additions & 1 deletion src/llama_recipes/pipeline/inference.py
@@ -27,7 +27,8 @@ def main(**kwargs):
random.seed(train_config.seed)

model, tokenizer = model_factory(train_config, model_config, **kwargs)
model.to(kwargs.get("device", "cuda"))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # FIX(MZY): put the whole model to device.
model.to(device)
model.eval()

print("=====================================")
8 changes: 7 additions & 1 deletion src/llama_recipes/pipeline/model_factory.py
@@ -1,8 +1,14 @@
import torch
from llama_recipes.models.slam_model import setup_model, setup_tokenizer

def model_factory(train_config, model_config, **kwargs):

tokenizer = setup_tokenizer(train_config, model_config, **kwargs)
model = setup_model(tokenizer, train_config, model_config, **kwargs).cuda()
model = setup_model(tokenizer, train_config, model_config, **kwargs)
ckpt_path = kwargs.get("ckpt_path", None) # FIX(MZY): load the model ckpt (mainly the projector; see model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft)
if ckpt_path is not None:
print("loading ckpt from: ", ckpt_path)
ckpt_dict = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(ckpt_dict, strict=False)

return model, tokenizer
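
Because load_state_dict is called with strict=False, only the keys present in model.pt (chiefly the projector written by save_model_checkpoint_peft) are restored; the LLM and encoder keys are merely reported as missing. A toy sketch of the pattern (the TinySlam class is hypothetical, not the repo's model):

import torch
import torch.nn as nn

class TinySlam(nn.Module):  # hypothetical stand-in for slam_model
    def __init__(self):
        super().__init__()
        self.llm = nn.Linear(8, 8)                       # restored separately (e.g. from a PEFT checkpoint)
        self.speech_encoder_projector = nn.Linear(4, 8)  # the part saved to model.pt

model = TinySlam()
proj_only = {k: v for k, v in model.state_dict().items() if k.startswith("speech_encoder_projector")}
torch.save(proj_only, "/tmp/model.pt")

restored = TinySlam()
missing, unexpected = restored.load_state_dict(torch.load("/tmp/model.pt", map_location="cpu"), strict=False)
print(missing)  # llm.* keys are missing but tolerated because strict=False
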
7 changes: 7 additions & 0 deletions src/llama_recipes/utils/config_utils.py
@@ -96,6 +96,13 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
kwargs["collate_fn"] = default_data_collator
else:
# raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
if train_config.enable_fsdp:
kwargs["sampler"] = DistributedSampler(
dataset,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
shuffle=mode=="train",
)
kwargs["batch_size"] = batch_size
kwargs["drop_last"] = True
kwargs["collate_fn"] = dataset.collator
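
The DistributedSampler added above gives each rank a disjoint shard of the custom dataset when FSDP is enabled. A minimal usage sketch (not from this commit), assuming the process group has already been initialized by torchrun:

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset, DistributedSampler

dataset = TensorDataset(torch.arange(100).float())
sampler = DistributedSampler(dataset, rank=dist.get_rank(), num_replicas=dist.get_world_size(), shuffle=True)
loader = DataLoader(dataset, batch_size=8, sampler=sampler, drop_last=True)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reshuffle consistently across ranks each epoch
    for batch in loader:
        pass
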
28 changes: 24 additions & 4 deletions src/llama_recipes/utils/train_utils.py
@@ -144,10 +144,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
print(f"we are about to save the PEFT modules")
else:
print(f"we are about to save the PEFT modules")
# model.save_pretrained(train_config.output_dir)
save_model_checkpoint_peft(
model, optimizer, rank, train_config, epoch=epoch
)
if train_config.enable_fsdp:
if rank==0:
save_model_checkpoint_peft(
model, optimizer, rank, train_config, epoch=epoch
)
dist.barrier()
else:
# model.save_pretrained(train_config.output_dir)
save_model_checkpoint_peft(
model, optimizer, rank, train_config, epoch=epoch
)
if train_config.enable_fsdp:
if rank==0:
print(f"PEFT modules are saved in {train_config.output_dir} directory")
@@ -189,6 +196,19 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
val_loss.append(best_val_loss)
val_prep.append(eval_ppl)
if train_config.run_test_during_validation:
if train_config.enable_fsdp:
if rank==0:
print("=====================================")
print(f"Test the file {train_config.run_test_during_validation_file} during validation:")
print(model.generate(train_config.run_test_during_validation_file))
print("=====================================")
dist.barrier()
else:
print("=====================================")
print(f"Test the file {train_config.run_test_during_validation_file} during validation:")
print(model.generate(train_config.run_test_during_validation_file))
print("=====================================")
if train_config.enable_fsdp:
if rank==0:
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
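
Both new branches in train() follow the same pattern: do the work (saving PEFT modules, or running the test-file generation) on rank 0 only, then hit a barrier so the other ranks wait. A minimal sketch of that pattern (not from this commit; the helper name is hypothetical), assuming an initialized process group:

import torch.distributed as dist

def run_on_rank0_then_sync(fn):
    """Run fn on rank 0 only, then block every rank until it has finished."""
    if dist.get_rank() == 0:
        fn()           # e.g. save_model_checkpoint_peft(...) or model.generate(test_wav_path)
    dist.barrier()     # ranks 1..world_size-1 wait here for rank 0
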
