From 093785170680f5a677ec690f70569317f89f0ace Mon Sep 17 00:00:00 2001 From: ddlBoJack Date: Sun, 3 Dec 2023 15:11:40 +0800 Subject: [PATCH 1/2] update echat dataset, freeze_llm --- scripts/finetune_echat.sh | 56 +++++++++++++------ scripts/inference_echat.sh | 14 +++-- src/llama_recipes/configs/datasets.py | 3 +- src/llama_recipes/configs/training.py | 1 + src/llama_recipes/datasets/echat_dataset.py | 38 ++++++++----- .../model_checkpointing/checkpoint_handler.py | 3 +- src/llama_recipes/models/slam_model.py | 13 ++++- src/llama_recipes/pipeline/inference.py | 2 +- src/llama_recipes/utils/train_utils.py | 22 +++++++- 9 files changed, 108 insertions(+), 44 deletions(-) diff --git a/scripts/finetune_echat.sh b/scripts/finetune_echat.sh index d4c1b381..2ae72b87 100644 --- a/scripts/finetune_echat.sh +++ b/scripts/finetune_echat.sh @@ -1,8 +1,9 @@ #!/bin/bash #export PYTHONPATH=/root/whisper:$PYTHONPATH -export CUDA_VISIBLE_DEVICES=0,1 +export CUDA_VISIBLE_DEVICES=0 export CUDA_LAUNCH_BLOCKING=1 -export OMP_NUM_THREADS=1 +# export OMP_NUM_THREADS=1 +# export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 # debug setting for multiple gpus # export NCCL_DEBUG=INFO @@ -11,15 +12,17 @@ export OMP_NUM_THREADS=1 cd /root/SLAM-LLM -speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/base.pt +# speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt +speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf -output_dir=/nfs/maziyang.mzy/models/llama-2-hf-finetune +output_dir=/nfs/maziyang.mzy/models/llama-2-hf-proj2048 # -m debugpy --listen 5678 --wait-for-client if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then -python src/llama_recipes/pipeline/finetune.py \ +python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \ --model_name echat \ ---use_peft --peft_method lora \ +--freeze_llm \ +--use_fp16 \ --llm_name llama-2-7b-hf \ --llm_path $llm_path \ --encoder_name whisper \ @@ -32,11 +35,29 @@ python src/llama_recipes/pipeline/finetune.py \ --custom_dataset.max_words 1024 \ --num_epochs 100 \ --batch_size_training 2 \ +--val_batch_size 2 \ --output_dir $output_dir \ ---run_test_during_validation true \ ---run_test_during_validation_file /nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_M013.wav \ -# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \ -# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \ +--run_test_during_validation \ +--run_test_during_validation_file /nfs/zhifu.gzf/data/IEMOCAP_full_release/Session5/sentences/wav/Ses05M_impro04/Ses05M_impro04_M040.wav \ +# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1/model.pt" \ +# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1" +# --use_peft --peft_method lora \ + +# train +# {"trans": "Well, do you have your passport?\n", +# "emotion": "xxx", +# "wav": "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_F009.wav"} +# {"trans": "No, I don't have a passport.\n", +# "emotion": "neu", +# "wav": "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_M010.wav"} + +# val +# {"trans": "Yeah, well thanks for your help.\n", +# "emotion": "ang", +# "wav": "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session5/sentences/wav/Ses05M_impro04/Ses05M_impro04_M040.wav"} +# {"trans": "I'm sorry. 
Good luck, man.\n", +# "emotion": "xxx", +# "wav": "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session5/sentences/wav/Ses05M_impro04/Ses05M_impro04_F038.wav"} else torchrun \ @@ -44,8 +65,8 @@ torchrun \ --nproc_per_node 2 \ src/llama_recipes/pipeline/finetune.py \ --model_name echat \ ---enable_fsdp \ ---use_peft --peft_method lora \ +--freeze_llm \ +--enable_fsdp --fsdp_config.pure_bf16 \ --llm_name llama-2-7b-hf \ --llm_path $llm_path \ --encoder_name whisper \ @@ -57,11 +78,12 @@ src/llama_recipes/pipeline/finetune.py \ --batching_strategy custom \ --custom_dataset.max_words 1024 \ --num_epochs 100 \ ---batch_size_training 8 \ ---val_batch_size 8 \ +--batch_size_training 4 \ +--val_batch_size 4 \ --output_dir $output_dir \ --run_test_during_validation \ --run_test_during_validation_file /nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_M013.wav \ -# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \ -# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \ -fi \ No newline at end of file +# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1/model.pt" \ +# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1" +# --use_peft --peft_method lora \ +fi diff --git a/scripts/inference_echat.sh b/scripts/inference_echat.sh index 49572adc..87550aca 100644 --- a/scripts/inference_echat.sh +++ b/scripts/inference_echat.sh @@ -1,11 +1,12 @@ #!/bin/bash #export PYTHONPATH=/root/whisper:$PYTHONPATH -export CUDA_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=1 export CUDA_LAUNCH_BLOCKING=1 cd /root/SLAM-LLM -speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/base.pt +# speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/base.pt +speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf output_dir=/nfs/maziyang.mzy/models/llama-2-hf-finetune @@ -13,7 +14,8 @@ output_dir=/nfs/maziyang.mzy/models/llama-2-hf-finetune #python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \ python src/llama_recipes/pipeline/inference.py \ --model_name echat \ ---use_peft --peft_method lora \ +--freeze_llm \ +--use_fp16 \ --quantization \ --llm_name llama-2-7b-hf \ --llm_path $llm_path \ @@ -28,5 +30,7 @@ python src/llama_recipes/pipeline/inference.py \ --num_epochs 1 \ --batch_size_training 2 \ --output_dir $output_dir \ ---ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/0/model.pt" \ ---peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/0" \ No newline at end of file +--ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1/model.pt" \ +--wav_path "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session5/sentences/wav/Ses05M_impro04/Ses05M_impro04_F035.wav" +# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1" +# --use_peft --peft_method lora \ \ No newline at end of file diff --git a/src/llama_recipes/configs/datasets.py b/src/llama_recipes/configs/datasets.py index ab7e34ec..45c232aa 100644 --- a/src/llama_recipes/configs/datasets.py +++ b/src/llama_recipes/configs/datasets.py @@ -33,4 +33,5 @@ class custom_dataset: train_split: str = "train" test_split: str = "validation" data_path: str = NotImplemented - max_words: int = NotImplemented \ No newline at end of file + max_words: int = NotImplemented + max_mel: int = 1000 \ No newline at end of file diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py index 8a24669e..86ab1268 100644 --- 
a/src/llama_recipes/configs/training.py +++ b/src/llama_recipes/configs/training.py @@ -38,3 +38,4 @@ class train_config: use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels run_test_during_validation: bool = False run_test_during_validation_file: str = "test.wav" + freeze_llm: bool = False diff --git a/src/llama_recipes/datasets/echat_dataset.py b/src/llama_recipes/datasets/echat_dataset.py index 77b0c598..8723cd23 100644 --- a/src/llama_recipes/datasets/echat_dataset.py +++ b/src/llama_recipes/datasets/echat_dataset.py @@ -23,7 +23,7 @@ def __init__( super().__init__() self.dataset_config = dataset_config - self.max_words = dataset_config.max_words + self.max_mel = dataset_config.max_mel self.tokenizer = tokenizer self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss self.prompt_template = "USER: {}\n ASSISTANT:" @@ -31,10 +31,27 @@ def __init__( with open(dataset_config.data_path, 'r') as file: data = file.readlines() + + sentence_list = [] + + for item in data: + dialog_name, dialog = item.split('\t', 1) + dialog_list = eval(dialog) + for sentence_id in range(len(dialog_list)-2): + if 'emotion' in dialog_list[sentence_id].keys() and 'emotion' in dialog_list[sentence_id+1].keys(): + if dialog_list[sentence_id+1]['emotion'] != 'xxx': + sentence_dict = {} + sentence_dict['pre_wav'] = dialog_list[sentence_id]['wav'] + sentence_dict['post_emotion'] = dialog_list[sentence_id+1]['emotion'] + sentence_dict['post_trans'] = dialog_list[sentence_id+1]['trans'] + sentence_list.append(sentence_dict) + + total_sentence = len(sentence_list) + print(f"Using {total_sentence} sentence totally.") if split == "train": - self.data = data[:60] + self.data = sentence_list[:int(total_sentence * 0.9)] else: - self.data = data[60:] + self.data = sentence_list[int(total_sentence * 0.9):] def __len__(self) -> int: @@ -42,17 +59,10 @@ def __len__(self) -> int: def __getitem__(self, index): item = self.data[index] - dialog_name, dialog = item.split('\t', 1) - dialog_list = eval(dialog) - - while True: - sentence_id = random.randint(0, len(dialog_list)-2) - if 'emotion' in dialog_list[sentence_id].keys() and 'emotion' in dialog_list[sentence_id+1].keys(): - if dialog_list[sentence_id]['emotion'] != 'xxx' and dialog_list[sentence_id+1]['emotion'] != 'xxx': - break - speech_raw = whisper.load_audio(dialog_list[sentence_id]['wav']) + + speech_raw = whisper.load_audio(item['pre_wav']) # speech_raw = whisper.pad_or_trim(speech_raw) - speech_mel = whisper.log_mel_spectrogram(speech_raw).permute(1,0) + speech_mel = whisper.log_mel_spectrogram(speech_raw).permute(1,0)[:self.max_mel] prompt=""" Please provide an emotional response based on the emotional speech you hear. 
@@ -65,7 +75,7 @@ def __getitem__(self, index): """ prompt = self.prompt_template.format(prompt) - answer = self.answer_template.format(dialog_list[sentence_id+1]['emotion'], dialog_list[sentence_id+1]['trans']) + answer = self.answer_template.format(item['post_emotion'], item['post_trans']) prompt_ids = self.tokenizer.encode(prompt) diff --git a/src/llama_recipes/model_checkpointing/checkpoint_handler.py b/src/llama_recipes/model_checkpointing/checkpoint_handler.py index c124c13a..a716d398 100644 --- a/src/llama_recipes/model_checkpointing/checkpoint_handler.py +++ b/src/llama_recipes/model_checkpointing/checkpoint_handler.py @@ -164,7 +164,8 @@ def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0): print(f"--> saving model ...") save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch)) os.makedirs(save_dir, exist_ok=True) - model.llm.save_pretrained(save_dir) + if not cfg.freeze_llm: + model.llm.save_pretrained(save_dir) save_full_path = os.path.join(save_dir, "model.pt") cpu_state = model.state_dict() diff --git a/src/llama_recipes/models/slam_model.py b/src/llama_recipes/models/slam_model.py index 6169d48a..7ba686c5 100644 --- a/src/llama_recipes/models/slam_model.py +++ b/src/llama_recipes/models/slam_model.py @@ -108,7 +108,7 @@ def setup_llm(train_config, model_config, **kwargs): model.print_trainable_parameters() if kwargs.get("peft_ckpt", None): - print("loading ckpt from: ", kwargs.get("peft_ckpt")) + print("loading peft_ckpt from: ", kwargs.get("peft_ckpt")) model = PeftModel.from_pretrained(model, kwargs.get("peft_ckpt")) return model @@ -132,9 +132,16 @@ def __init__( # llama self.llm = setup_llm(train_config, model_config, **kwargs) + if train_config.freeze_llm: + for name, param in self.llm.named_parameters(): + param.requires_grad = False # projector - self.speech_encoder_projector = nn.Linear(self.speech_encoder.ln_post.normalized_shape[0], self.llm.config.hidden_size) + self.speech_encoder_projector = nn.Sequential( + nn.Linear(self.speech_encoder.ln_post.normalized_shape[0], 2048), + nn.ReLU(), + nn.Linear(2048, self.llm.config.hidden_size), + ) # tokenizer self.tokenizer = tokenizer @@ -229,7 +236,7 @@ def generate( output = self.llm.generate( inputs_embeds=inputs_embeds, max_length=kwargs.get("max_length", 200), - num_beams=kwargs.get("num_beams", 1), + num_beams=kwargs.get("num_beams", 4), do_sample=kwargs.get("do_sample", True), min_length=kwargs.get("min_length", 1), top_p=kwargs.get("top_p", 0.9), diff --git a/src/llama_recipes/pipeline/inference.py b/src/llama_recipes/pipeline/inference.py index 935d7cdf..106286f1 100644 --- a/src/llama_recipes/pipeline/inference.py +++ b/src/llama_recipes/pipeline/inference.py @@ -34,7 +34,7 @@ def main(**kwargs): print("=====================================") # wav_path = input("Your Wav Path:\n") # prompt = input("Your Prompt:\n") - wav_path = "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_M001.wav" + wav_path = kwargs.get('wav_path') print(model.generate(wav_path)) diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py index 9fb8ef5a..710b63eb 100644 --- a/src/llama_recipes/utils/train_utils.py +++ b/src/llama_recipes/utils/train_utils.py @@ -18,7 +18,12 @@ from transformers import LlamaTokenizer -from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_model_checkpoint_peft +from llama_recipes.model_checkpointing import( + 
save_model_checkpoint, + save_model_and_optimizer_sharded, + save_optimizer_checkpoint, + save_model_checkpoint_peft +) from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper from llama_recipes.utils.memory_utils import MemoryTrace @@ -135,7 +140,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche if train_config.run_validation: eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer) checkpoint_start_time = time.perf_counter() - if train_config.save_model and eval_epoch_loss < best_val_loss: + if train_config.save_model: if train_config.enable_fsdp: dist.barrier() if train_config.use_peft: @@ -160,6 +165,19 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche print(f"PEFT modules are saved in {train_config.output_dir} directory") else: print(f"PEFT modules are saved in {train_config.output_dir} directory") + + elif not train_config.use_peft and train_config.freeze_llm: + print(f"llm is frozen, we are about to save other parts.") + if train_config.enable_fsdp: + if rank==0: + save_model_checkpoint_peft( + model, optimizer, rank, train_config, epoch=epoch + ) + dist.barrier() + else: + save_model_checkpoint_peft( + model, optimizer, rank, train_config, epoch=epoch + ) else: if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT: From 3d6216e910e0155745f04f367b9e08f6dcba5873 Mon Sep 17 00:00:00 2001 From: ddlBoJack Date: Mon, 4 Dec 2023 14:30:51 +0800 Subject: [PATCH 2/2] update encoder_projector downsample, fp16 --- scripts/finetune_echat.sh | 23 +++-- src/llama_recipes/configs/model.py | 4 +- src/llama_recipes/configs/training.py | 1 + src/llama_recipes/datasets/echat_dataset.py | 12 ++- src/llama_recipes/datasets/speech_dataset.py | 3 +- .../model_checkpointing/checkpoint_handler.py | 6 +- src/llama_recipes/models/slam_model.py | 92 +++++++++++++------ src/llama_recipes/utils/compute_utils.py | 3 + src/llama_recipes/utils/metric.py | 11 +-- src/llama_recipes/utils/train_utils.py | 54 ++++++++--- 10 files changed, 146 insertions(+), 63 deletions(-) create mode 100644 src/llama_recipes/utils/compute_utils.py diff --git a/scripts/finetune_echat.sh b/scripts/finetune_echat.sh index 2ae72b87..c7fd9999 100644 --- a/scripts/finetune_echat.sh +++ b/scripts/finetune_echat.sh @@ -1,6 +1,6 @@ #!/bin/bash #export PYTHONPATH=/root/whisper:$PYTHONPATH -export CUDA_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=0,1 export CUDA_LAUNCH_BLOCKING=1 # export OMP_NUM_THREADS=1 # export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 @@ -15,19 +15,22 @@ cd /root/SLAM-LLM # speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf -output_dir=/nfs/maziyang.mzy/models/llama-2-hf-proj2048 +output_dir=/nfs/maziyang.mzy/models/llama-2-hf-proj2048-debug # -m debugpy --listen 5678 --wait-for-client if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \ --model_name echat \ +--freeze_encoder \ --freeze_llm \ --use_fp16 \ --llm_name llama-2-7b-hf \ --llm_path $llm_path \ --encoder_name whisper \ +--encoder_ds_rate 2 \ --encoder_path $speech_encoder_path \ --encoder_projector linear \ +--encoder_projector_ds_rate 5 \ --dataset custom_dataset \ --custom_dataset.file src/llama_recipes/datasets/echat_dataset.py:get_audio_dataset \ --custom_dataset.data_path 
/nfs/zhifu.gzf/data/IEMOCAP_full_release/datalist.jsonl \ @@ -65,25 +68,29 @@ torchrun \ --nproc_per_node 2 \ src/llama_recipes/pipeline/finetune.py \ --model_name echat \ ---freeze_llm \ ---enable_fsdp --fsdp_config.pure_bf16 \ +--freeze_encoder \ +--use_fp16 \ +--enable_fsdp \ +--use_peft --peft_method lora \ --llm_name llama-2-7b-hf \ --llm_path $llm_path \ --encoder_name whisper \ +--encoder_ds_rate 2 \ --encoder_path $speech_encoder_path \ --encoder_projector linear \ +--encoder_projector_ds_rate 5 \ --dataset custom_dataset \ --custom_dataset.file src/llama_recipes/datasets/echat_dataset.py:get_audio_dataset \ --custom_dataset.data_path /nfs/zhifu.gzf/data/IEMOCAP_full_release/datalist.jsonl \ --batching_strategy custom \ --custom_dataset.max_words 1024 \ --num_epochs 100 \ ---batch_size_training 4 \ ---val_batch_size 4 \ +--batch_size_training 2 \ +--val_batch_size 2 \ --output_dir $output_dir \ --run_test_during_validation \ ---run_test_during_validation_file /nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_M013.wav \ +--run_test_during_validation_file /nfs/zhifu.gzf/data/IEMOCAP_full_release/Session5/sentences/wav/Ses05M_impro04/Ses05M_impro04_M040.wav \ # --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1/model.pt" \ # --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1" -# --use_peft --peft_method lora \ +# --freeze_llm \ fi diff --git a/src/llama_recipes/configs/model.py b/src/llama_recipes/configs/model.py index 8b869b79..46c8668c 100644 --- a/src/llama_recipes/configs/model.py +++ b/src/llama_recipes/configs/model.py @@ -6,5 +6,7 @@ class model_config: llm_name: str = "llama-2-7b-hf" llm_path: str = "PATH/to/LLAMA/7B" encoder_name: str = None + encoder_ds_rate: int = 2 encoder_path: str = None - encoder_projector: str = "linear" \ No newline at end of file + encoder_projector: str = "linear" + encoder_projector_ds_rate: int = 5 \ No newline at end of file diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py index 86ab1268..51606987 100644 --- a/src/llama_recipes/configs/training.py +++ b/src/llama_recipes/configs/training.py @@ -39,3 +39,4 @@ class train_config: run_test_during_validation: bool = False run_test_during_validation_file: str = "test.wav" freeze_llm: bool = False + freeze_encoder: bool = False diff --git a/src/llama_recipes/datasets/echat_dataset.py b/src/llama_recipes/datasets/echat_dataset.py index 8723cd23..d9fcdec4 100644 --- a/src/llama_recipes/datasets/echat_dataset.py +++ b/src/llama_recipes/datasets/echat_dataset.py @@ -11,6 +11,7 @@ import torchaudio from torch.utils.data import Dataset import whisper +from llama_recipes.utils.compute_utils import calculate_output_length_1d class EChatDataset(Dataset): @@ -48,10 +49,16 @@ def __init__( total_sentence = len(sentence_list) print(f"Using {total_sentence} sentence totally.") + # if split == "train": + # self.data = sentence_list[:int(total_sentence * 0.9)] + # else: + # self.data = sentence_list[int(total_sentence * 0.9):] + + # debug if split == "train": - self.data = sentence_list[:int(total_sentence * 0.9)] + self.data = sentence_list[:8] else: - self.data = sentence_list[int(total_sentence * 0.9):] + self.data = sentence_list[8:16] def __len__(self) -> int: @@ -81,6 +88,7 @@ def __getitem__(self, index): prompt_length = len(prompt_ids) speech_length = (speech_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats + speech_length = 
calculate_output_length_1d(speech_length, 5, 5) # ad-hoc for 5x cov1d downsample speech_pseudo = torch.full((speech_length,),-1) example = prompt + answer #FIX(MZY): avoid putting a bos token before answer. diff --git a/src/llama_recipes/datasets/speech_dataset.py b/src/llama_recipes/datasets/speech_dataset.py index e6c6e4bb..126c3d84 100644 --- a/src/llama_recipes/datasets/speech_dataset.py +++ b/src/llama_recipes/datasets/speech_dataset.py @@ -39,7 +39,7 @@ def __init__(self, if split == "train": self.data_list = contents[:-1000] else: - self.data_list = contents[1000:] + self.data_list = contents[-1000:] def get_source_len(self, data_dict): return data_dict["source_len"] @@ -76,6 +76,7 @@ def __getitem__(self, index): prompt_length = len(prompt_ids) speech_length = (speech_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats + speech_length = calculate_output_length_1d(speech_length, 5, 5) # ad-hoc for 5x cov1d downsample speech_pseudo = torch.full((speech_length,), -1) example = prompt + answer # FIX(MZY): avoid putting a bos token before answer. diff --git a/src/llama_recipes/model_checkpointing/checkpoint_handler.py b/src/llama_recipes/model_checkpointing/checkpoint_handler.py index a716d398..d1d6eed6 100644 --- a/src/llama_recipes/model_checkpointing/checkpoint_handler.py +++ b/src/llama_recipes/model_checkpointing/checkpoint_handler.py @@ -170,8 +170,12 @@ def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0): save_full_path = os.path.join(save_dir, "model.pt") cpu_state = model.state_dict() project_dict = {} + if not cfg.freeze_encoder: + for key in cpu_state.keys(): + if key.startswith("encoder."): + project_dict[key] = cpu_state[key] for key in cpu_state.keys(): - if "_projector" in key: + if key.startswith("encoder_projector."): project_dict[key] = cpu_state[key] torch.save(project_dict, save_full_path) diff --git a/src/llama_recipes/models/slam_model.py b/src/llama_recipes/models/slam_model.py index 84d5051e..e309565b 100644 --- a/src/llama_recipes/models/slam_model.py +++ b/src/llama_recipes/models/slam_model.py @@ -51,6 +51,23 @@ def extract_variable_length_features(self, x: torch.Tensor): x = self.ln_post(x) return x +def setup_encoder(train_config, model_config, **kwargs): + encoder_list = model_config.encoder_name.split(",") + if len(encoder_list) == 1: + encoder_name = encoder_list[0] + if encoder_name == "whisper" or "qwen-audio": + encoder = whisper.load_model(model_config.encoder_path).encoder + encoder.extract_variable_length_features = types.MethodType(extract_variable_length_features, encoder) + if encoder_name == "audio-mae": #TODO + pass + + if train_config.freeze_encoder: + for name, param in encoder.named_parameters(): + param.requires_grad = False + encoder.eval() + + return encoder + def setup_llm(train_config, model_config, **kwargs): from pkg_resources import packaging use_cache = False if train_config.enable_fsdp else None @@ -104,6 +121,11 @@ def setup_llm(train_config, model_config, **kwargs): if train_config.quantization: model = prepare_model_for_kbit_training(model) + if train_config.freeze_llm: # TODO:to test offical `freeze_layers` and `num_freeze_layers` + for name, param in model.named_parameters(): + param.requires_grad = False + model.eval() + if train_config.use_peft: peft_config = generate_peft_config(train_config, kwargs) model = get_peft_model(model, peft_config) @@ -115,6 +137,29 @@ def setup_llm(train_config, model_config, **kwargs): return model +def setup_encoder_projector(train_config, 
model_config, **kwargs): + if model_config.encoder_projector == "linear": + return EncoderProjector(model_config) + +class EncoderProjector(nn.Module): + def __init__(self, config): + super(EncoderProjector, self).__init__() + self.conv1d = nn.Conv1d(in_channels=1280, out_channels=1280, kernel_size=config.encoder_ds_rate, stride=config.encoder_ds_rate, padding=0) + self.linear1 = nn.Linear(1280, 2048) + self.relu1 = nn.ReLU() + self.linear2 = nn.Linear(2048, 4096) + self.relu2 = nn.ReLU() + + def forward(self, x): + x = x.transpose(1, 2) + x = self.conv1d(x) + x = x.transpose(1, 2) + x = self.relu1(x) + x = self.linear1(x) + x = self.relu2(x) + x = self.linear2(x) + return x + class slam_model(nn.Module): def __init__( @@ -125,30 +170,22 @@ def __init__( **kwargs ): super().__init__() - # whisper - self.speech_encoder = whisper.load_model(model_config.encoder_path).encoder - self.speech_encoder.extract_variable_length_features = types.MethodType(extract_variable_length_features, self.speech_encoder) - for name, param in self.speech_encoder.named_parameters(): - param.requires_grad = False - self.speech_encoder.eval() - - # llama + # modality encoder + self.encoder = setup_encoder(train_config, model_config, **kwargs) + + # llm self.llm = setup_llm(train_config, model_config, **kwargs) - if train_config.freeze_llm: - for name, param in self.llm.named_parameters(): - param.requires_grad = False # projector - self.speech_encoder_projector = nn.Sequential( - nn.Linear(self.speech_encoder.ln_post.normalized_shape[0], 2048), - nn.ReLU(), - nn.Linear(2048, self.llm.config.hidden_size), - ) + self.encoder_projector = setup_encoder_projector(train_config, model_config, **kwargs) # tokenizer self.tokenizer = tokenizer self.metric = kwargs.get("metric", "acc") + self.train_config = train_config + self.model_config = model_config + def forward(self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, @@ -165,10 +202,10 @@ def forward(self, speech_mel = kwargs.get("speech_mel", None) speech_mask = kwargs.get("speech_mask", None) - speech_encoder_outs = None + encoder_outs = None if speech_mel is not None: - speech_encoder_outs = self.speech_encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1)) - speech_encoder_outs = self.speech_encoder_projector(speech_encoder_outs) + encoder_outs = self.encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1)) # bs*seq*dim + encoder_outs = self.encoder_projector(encoder_outs) input_ids[input_ids == -1] = 0 # print(input_ids[0]) @@ -180,18 +217,17 @@ def forward(self, inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids) batch_size, token_num, dims = inputs_embeds.shape - _, l, _ = speech_encoder_outs.shape - speech_encoder_outs_pad = F.pad(speech_encoder_outs, (0, 0, 0, token_num-l, 0, 0), value=0.0) - inputs_embeds = speech_encoder_outs_pad * speech_mask[:, :, None] + inputs_embeds * (~speech_mask[:, :, None]) + _, l, _ = encoder_outs.shape + encoder_outs_pad = F.pad(encoder_outs, (0, 0, 0, token_num-l, 0, 0), value=0.0) + inputs_embeds = encoder_outs_pad * speech_mask[:, :, None] + inputs_embeds * (~speech_mask[:, :, None]) model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels) acc = -1 if self.metric: with torch.no_grad(): - logits = model_outputs.logits - batch_size, l, vocab_size = logits.size() - acc = compute_accuracy(logits.contiguous().view(-1, vocab_size), labels, ignore_label=-100) + preds = torch.argmax(model_outputs.logits, -1) + acc = 
compute_accuracy(preds.detach(), labels.detach(), ignore_label=-100) return model_outputs, acc @@ -217,8 +253,8 @@ def generate( # speech_raw = whisper.pad_or_trim(speech_raw) speech_mel = whisper.log_mel_spectrogram(speech_raw).permute(1,0)[None, :, :].to(device) - speech_encoder_outs = self.speech_encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1)) - speech_encoder_outs = self.speech_encoder_projector(speech_encoder_outs) + encoder_outs = self.encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1)) + encoder_outs = self.encoder_projector(encoder_outs) prompt=""" Please provide an emotional response based on the emotional speech you hear. @@ -238,7 +274,7 @@ def generate( else: inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids) - inputs_embeds = torch.cat((speech_encoder_outs, inputs_embeds[None, :, :]), dim=1) # [speech,prompt] + inputs_embeds = torch.cat((encoder_outs, inputs_embeds[None, :, :]), dim=1) # [speech,prompt] atts = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(inputs_embeds.device) diff --git a/src/llama_recipes/utils/compute_utils.py b/src/llama_recipes/utils/compute_utils.py new file mode 100644 index 00000000..14328b29 --- /dev/null +++ b/src/llama_recipes/utils/compute_utils.py @@ -0,0 +1,3 @@ + +def calculate_output_length_1d(L_in, kernel_size, stride, padding=0): + return (L_in + 2 * padding - kernel_size) // stride + 1 \ No newline at end of file diff --git a/src/llama_recipes/utils/metric.py b/src/llama_recipes/utils/metric.py index 9fd372b6..2de2129a 100644 --- a/src/llama_recipes/utils/metric.py +++ b/src/llama_recipes/utils/metric.py @@ -4,20 +4,17 @@ def compute_accuracy(pad_outputs, pad_targets, ignore_label): """Calculate accuracy. Args: - pad_outputs (Tensor): Prediction tensors (B * Lmax, D). - pad_targets (LongTensor): Target label tensors (B, Lmax, D). + pad_outputs (LongTensor): Prediction tensors (B, Lmax). + pad_targets (LongTensor): Target label tensors (B, Lmax). ignore_label (int): Ignore label id. Returns: float: Accuracy value (0.0 - 1.0). 
""" - pad_pred = pad_outputs.view( - pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1) - ).argmax(2) mask = pad_targets != ignore_label numerator = torch.sum( - pad_pred.masked_select(mask) == pad_targets.masked_select(mask) + pad_outputs.masked_select(mask) == pad_targets.masked_select(mask) ) denominator = torch.sum(mask) - return float(numerator) / float(denominator) \ No newline at end of file + return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type \ No newline at end of file diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py index f75a2c2a..62aae5f0 100644 --- a/src/llama_recipes/utils/train_utils.py +++ b/src/llama_recipes/utils/train_utils.py @@ -26,6 +26,7 @@ ) from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper from llama_recipes.utils.memory_utils import MemoryTrace +from llama_recipes.utils.metric import compute_accuracy def set_tokenizer_params(tokenizer: LlamaTokenizer): @@ -65,8 +66,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche train_prep = [] train_loss = [] + train_acc = [] val_prep = [] val_loss =[] + val_acc = [] epoch_times = [] checkpoint_times = [] results = {} @@ -76,6 +79,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche with MemoryTrace() as memtrace: # track the memory usage model.train() total_loss = 0.0 + total_acc = 0.0 total_length = len(train_dataloader)//gradient_accumulation_steps pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True) for step, batch in enumerate(train_dataloader): @@ -85,14 +89,14 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche else: batch[key] = batch[key].to('cuda:0') with autocast(): - outputs = model(**batch) - acc = -1 - if isinstance(outputs, tuple): - outputs, acc = outputs + outputs, *rest = model(**batch) + acc = rest[0] if rest else -1 loss = outputs.loss loss = loss / gradient_accumulation_steps + acc = acc / gradient_accumulation_steps total_loss += loss.detach().float() + total_acc += acc if train_config.use_fp16: # if fp16 is enabled, use gradient scaler to handle gradient update scaler.scale(loss).backward() @@ -117,13 +121,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche # Reducing total_loss across all devices if there's more than one CUDA device if torch.cuda.device_count() > 1 and train_config.enable_fsdp: dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(total_acc, op=dist.ReduceOp.SUM) train_epoch_loss = total_loss / len(train_dataloader) + train_epoch_acc = total_acc / len(train_dataloader) if train_config.enable_fsdp: train_epoch_loss = train_epoch_loss/world_size + train_epoch_acc = train_epoch_acc/world_size train_perplexity = torch.exp(train_epoch_loss) train_prep.append(train_perplexity) train_loss.append(train_epoch_loss) + train_acc.append(train_epoch_acc) if train_config.enable_fsdp: if rank==0: @@ -143,7 +151,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche lr_scheduler.step() if train_config.run_validation: - eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer) + eval_ppl, eval_epoch_loss, *rest = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer) checkpoint_start_time = time.perf_counter() if train_config.save_model: if train_config.enable_fsdp: @@ -217,20 +225,26 @@ def train(model, 
train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche print(f"best eval loss on epoch {epoch+1} is {best_val_loss}") else: print(f"best eval loss on epoch {epoch+1} is {best_val_loss}") - val_loss.append(best_val_loss) + val_loss.append(eval_epoch_loss) val_prep.append(eval_ppl) + if rest: + val_acc.append(rest[0]) + else: + val_acc.append(-1) if train_config.run_test_during_validation: if train_config.enable_fsdp: if rank==0: print("=====================================") print(f"Test the file {train_config.run_test_during_validation_file} during validation:") - print(model.generate(train_config.run_test_during_validation_file)) + with autocast(): + print(model.generate(train_config.run_test_during_validation_file)) print("=====================================") dist.barrier() else: print("=====================================") print(f"Test the file {train_config.run_test_during_validation_file} during validation:") - print(model.generate(train_config.run_test_during_validation_file)) + with autocast(): + print(model.generate(train_config.run_test_during_validation_file)) print("=====================================") if train_config.enable_fsdp: if rank==0: @@ -241,15 +255,19 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0 avg_train_prep = sum(train_prep)/len(train_prep) avg_train_loss = sum(train_loss)/len(train_loss) + avg_train_acc = sum(train_acc)/len(train_acc) if train_config.run_validation: avg_eval_prep = sum(val_prep)/len(val_prep) avg_eval_loss = sum(val_loss)/len(val_loss) + avg_eval_acc = sum(val_acc)/len(val_acc) results['avg_train_prep'] = avg_train_prep results['avg_train_loss'] = avg_train_loss + results['avg_train_acc'] = avg_train_acc if train_config.run_validation: results['avg_eval_prep'] = avg_eval_prep results['avg_eval_loss'] = avg_eval_loss + results['avg_eval_acc'] = avg_eval_acc results["avg_epoch_time"] = avg_epoch_time results["avg_checkpoint_time"] = avg_checkpoint_time @@ -276,6 +294,9 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): model.eval() eval_preds = [] eval_loss = 0.0 # Initialize evaluation loss + eval_acc = 0.0 + autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext # (Fix:MZY): fix expected scalar type mismatch in norm + with MemoryTrace() as memtrace: for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)): for key in batch.keys(): @@ -286,13 +307,13 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): # Ensure no gradients are computed for this scope to save memory with torch.no_grad(): # Forward pass and compute loss - outputs = model(**batch) - acc = -1 - if isinstance(outputs, tuple): - outputs, acc = outputs + with autocast(): # (Fix:MZY): fix expected scalar type mismatch in norm + outputs, *rest = model(**batch) + acc = rest[0] if rest else -1 loss = outputs.loss eval_loss += loss.detach().float() + eval_acc += acc # Decode predictions and add to evaluation predictions list preds = torch.argmax(outputs.logits, -1) eval_preds.extend( @@ -302,21 +323,24 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): # If there's more than one CUDA device, reduce evaluation loss across all devices if torch.cuda.device_count() > 1 and train_config.enable_fsdp: dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(eval_acc, 
op=dist.ReduceOp.SUM) # Compute average loss and perplexity eval_epoch_loss = eval_loss / len(eval_dataloader) + eval_epoch_acc = eval_acc / len(eval_dataloader) if train_config.enable_fsdp: eval_epoch_loss = eval_epoch_loss/world_size + eval_epoch_acc = eval_epoch_acc/world_size eval_ppl = torch.exp(eval_epoch_loss) # Print evaluation metrics if train_config.enable_fsdp: if local_rank==0: - print(f" {eval_ppl=} {eval_epoch_loss=}") + print(f" {eval_ppl=} {eval_epoch_loss=} {eval_epoch_acc=}") else: - print(f" {eval_ppl=} {eval_epoch_loss=}") + print(f" {eval_ppl=} {eval_epoch_loss=} {eval_epoch_acc=}") - return eval_ppl, eval_epoch_loss + return eval_ppl, eval_epoch_loss, eval_epoch_acc def freeze_transformer_layers(model, num_layer): for i, layer in enumerate(model.model.layers):
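
The rewritten EChatDataset.__init__ now flattens every dialog into (previous-utterance wav, next-utterance emotion and transcript) pairs once, up front, instead of re-sampling a random sentence inside __getitem__. A compact sketch of that pairing step, assuming the same tab-separated datalist.jsonl layout shown above (dialog name, then a Python-literal list of {'wav', 'emotion', 'trans'} dicts); build_sentence_list is an illustrative name, and ast.literal_eval stands in for the eval() call used in the dataset:

import ast

def build_sentence_list(lines):
    # lines: raw rows of datalist.jsonl, "dialog_name\t[{'wav': ..., 'emotion': ..., 'trans': ...}, ...]"
    sentence_list = []
    for item in lines:
        dialog_name, dialog = item.split('\t', 1)
        dialog_list = ast.literal_eval(dialog)   # safer stand-in for eval()
        for i in range(len(dialog_list) - 2):
            cur, nxt = dialog_list[i], dialog_list[i + 1]
            if 'emotion' in cur and 'emotion' in nxt and nxt['emotion'] != 'xxx':
                sentence_list.append({
                    'pre_wav': cur['wav'],            # audio the model listens to
                    'post_emotion': nxt['emotion'],   # emotion of the response it should produce
                    'post_trans': nxt['trans'],       # transcript of that response
                })
    return sentence_list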
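
The new calculate_output_length_1d helper is the standard 1-D convolution output-length formula, L_out = (L_in + 2*padding - kernel_size) // stride + 1. A small worked example of the length bookkeeping the datasets do before building speech_pseudo (the 2x Whisper front-end downsample followed by the 5x downsample assumed by encoder_projector_ds_rate); the 3000-frame input is illustrative, not taken from the repo:

def calculate_output_length_1d(L_in, kernel_size, stride, padding=0):
    return (L_in + 2 * padding - kernel_size) // stride + 1

mel_frames = 3000                          # log-mel frames fed to the Whisper encoder (about 30 s)
enc_frames = (mel_frames + 1) // 2         # ad-hoc 2x downsample from mel to encoder features
proj_frames = calculate_output_length_1d(enc_frames, 5, 5)  # assumed 5x conv1d downsample in the projector

print(mel_frames, enc_frames, proj_frames)  # 3000 1500 300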
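
The new EncoderProjector first downsamples the 1280-dim Whisper features along time with a strided Conv1d and then maps them into the 4096-dim Llama embedding space. A quick shape check of that module as defined in slam_model.py; the batch of 1500 encoder frames (roughly 30 s of audio) is an illustrative input, and ds_rate stands in for config.encoder_ds_rate:

import torch
import torch.nn as nn

class EncoderProjector(nn.Module):
    def __init__(self, ds_rate=2):
        super().__init__()
        self.conv1d = nn.Conv1d(1280, 1280, kernel_size=ds_rate, stride=ds_rate, padding=0)
        self.linear1 = nn.Linear(1280, 2048)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(2048, 4096)
        self.relu2 = nn.ReLU()

    def forward(self, x):                 # x: (batch, time, 1280)
        x = x.transpose(1, 2)             # (batch, 1280, time) for Conv1d
        x = self.conv1d(x)
        x = x.transpose(1, 2)             # back to (batch, time', 1280)
        x = self.relu1(x)
        x = self.linear1(x)
        x = self.relu2(x)
        x = self.linear2(x)
        return x

feats = torch.randn(2, 1500, 1280)        # illustrative Whisper encoder output
print(EncoderProjector(ds_rate=2)(feats).shape)  # torch.Size([2, 750, 4096])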
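
In slam_model.forward, the projected speech features are right-padded to the text sequence length and swapped into the token embeddings wherever speech_mask is set, which is the room the speech_pseudo placeholder tokens reserve in the datasets. A tensor-level sketch of that merge with made-up shapes and values:

import torch
import torch.nn.functional as F

batch, token_num, dims = 1, 6, 4
inputs_embeds = torch.randn(batch, token_num, dims)   # embeddings of the (prompt + answer) tokens
encoder_outs  = torch.randn(batch, 3, dims)           # projected speech features, l = 3 frames
speech_mask   = torch.tensor([[1, 1, 1, 0, 0, 0]], dtype=torch.bool)  # first l positions are speech

_, l, _ = encoder_outs.shape
encoder_outs_pad = F.pad(encoder_outs, (0, 0, 0, token_num - l, 0, 0), value=0.0)
merged = encoder_outs_pad * speech_mask[:, :, None] + inputs_embeds * (~speech_mask[:, :, None])
print(merged.shape)  # torch.Size([1, 6, 4]); speech embeddings in front, text embeddings behind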
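
compute_accuracy now takes already-argmaxed predictions of shape (B, Lmax) and ignores every position labelled ignore_label (-100, matching IGNORE_INDEX in the datasets). A toy call showing the padding-aware token accuracy; the tensors are made up:

import torch

def compute_accuracy(pad_outputs, pad_targets, ignore_label):
    mask = pad_targets != ignore_label
    numerator = torch.sum(pad_outputs.masked_select(mask) == pad_targets.masked_select(mask))
    denominator = torch.sum(mask)
    return numerator.float() / denominator.float()

preds  = torch.tensor([[5, 7, 2, 0]])
labels = torch.tensor([[5, 7, 3, -100]])   # last position is prompt/padding and is ignored
print(compute_accuracy(preds, labels, ignore_label=-100))  # tensor(0.6667): 2 of 3 valid tokens match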
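
With --freeze_llm (and --freeze_encoder) the frozen modules get requires_grad = False and save_model_checkpoint_peft keeps only the trainable pieces in model.pt: always the encoder_projector.* weights, plus the encoder.* weights when the encoder is not frozen. A minimal sketch of that filter; save_projector_only is an illustrative name, not a function in the repo:

import torch

def save_projector_only(model, save_full_path, freeze_encoder=True):
    # Keep only the parameters that are actually trained; the frozen LLM (and, when
    # freeze_encoder is set, the frozen encoder) can be re-loaded from their original
    # checkpoints at inference time.
    cpu_state = model.state_dict()
    project_dict = {}
    for key in cpu_state.keys():
        if key.startswith("encoder_projector."):
            project_dict[key] = cpu_state[key]
        elif not freeze_encoder and key.startswith("encoder."):
            project_dict[key] = cpu_state[key]
    torch.save(project_dict, save_full_path)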