diff --git a/.gitignore b/.gitignore
index 0c04a67d..c72a47be 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
 .DS_Store
 __pycache__
 .ipynb_checkpoints
+.idea/*
+transformers
diff --git a/scripts/finetune.sh b/scripts/finetune.sh
index 9bda13f4..deb3b184 100644
--- a/scripts/finetune.sh
+++ b/scripts/finetune.sh
@@ -1,17 +1,17 @@
 #!/bin/bash
-# export PYTHONPATH=/root/whisper:$PYTHONPATH
+#export PYTHONPATH=/root/whisper:$PYTHONPATH
 export CUDA_VISIBLE_DEVICES=0
 export CUDA_LAUNCH_BLOCKING=1
 
 cd /root/SLAM-LLM
 
-audio_encoder_path=/home/oss/maziyang.mzy/models/AudioMAE/finetuned.pth
-speech_encoder_path=/home/oss/maziyang.mzy/models/Whisper/base.pt
-llm_path=/home/oss/zhifu.gzf/ckpt/Llama-2-7b-hf
-output_dir=/nfs/maziyang.mzy/models/llama-2-hf-finetune
+speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/base.pt
+llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
+output_dir=/nfs/zhifu.gzf/models/llama-2-hf-finetune
 
 # -m debugpy --listen 5678 --wait-for-client
-python src/llama_recipes/pipeline/finetune.py \
+#python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \
+python src/llama_recipes/pipeline/finetune.py \
 --model_name echat \
 --use_peft --peft_method lora \
 --quantization \
@@ -27,4 +27,6 @@ python src/llama_recipes/pipeline/finetune.py \
 --custom_dataset.max_words 1024 \
 --num_epochs 1 \
 --batch_size_training 2 \
---output_dir $output_dir
\ No newline at end of file
+--output_dir $output_dir \
+--ckpt_path "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0/model.pt" \
+--peft_ckpt "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0"
\ No newline at end of file
diff --git a/scripts/inference.sh b/scripts/inference.sh
new file mode 100644
index 00000000..92973b46
--- /dev/null
+++ b/scripts/inference.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+#export PYTHONPATH=/root/whisper:$PYTHONPATH
+export CUDA_VISIBLE_DEVICES=0
+export CUDA_LAUNCH_BLOCKING=1
+
+cd /root/SLAM-LLM
+
+speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/base.pt
+llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
+output_dir=/nfs/zhifu.gzf/models/llama-2-hf-finetune
+
+# -m debugpy --listen 5678 --wait-for-client
+#python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \
+python src/llama_recipes/pipeline/inference.py \
+--model_name echat \
+--use_peft --peft_method lora \
+--quantization \
+--llm_name llama-2-7b-hf \
+--llm_path $llm_path \
+--encoder_name whisper \
+--encoder_path $speech_encoder_path \
+--encoder_projector linear \
+--dataset custom_dataset \
+--custom_dataset.file src/llama_recipes/datasets/speech_text_dataset.py:get_audio_dataset \
+--custom_dataset.data_path /nfs/zhifu.gzf/data/IEMOCAP_full_release/datalist.jsonl \
+--batching_strategy custom \
+--custom_dataset.max_words 1024 \
+--num_epochs 1 \
+--batch_size_training 2 \
+--output_dir $output_dir \
+--ckpt_path "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0/model.pt" \
+--peft_ckpt "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0"
\ No newline at end of file
diff --git a/src/llama_recipes/datasets/echat_dataset.py b/src/llama_recipes/datasets/echat_dataset.py
index 03fe00a6..08b2383e 100644
--- a/src/llama_recipes/datasets/echat_dataset.py
+++ b/src/llama_recipes/datasets/echat_dataset.py
@@ -66,13 +66,15 @@ def __getitem__(self, index):
         prompt = self.prompt_template.format(prompt)
         answer = self.answer_template.format(dialog_list[sentence_id+1]['emotion'], dialog_list[sentence_id+1]['trans'])
 
+        prompt_ids = self.tokenizer.encode(prompt)
+        prompt_length = len(prompt_ids)
         speech_length = (speech_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats
         speech_pseudo = torch.full((speech_length,),-1)
 
-        example = prompt + answer
-        example_ids = self.tokenizer.encode(example) # [prompt,answer]
+        example_ids = self.tokenizer.encode(answer) # FIX(GZF): [answer]
+        example_ids = prompt_ids + example_ids
         example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos]
         example_ids = torch.tensor(
             example_ids, dtype=torch.int64
@@ -80,8 +82,9 @@ def __getitem__(self, index):
         example_ids = torch.cat((speech_pseudo, example_ids)) # [speech,prompt,answer,eos]
 
         labels_ids = copy.deepcopy(example_ids) # [speech,prompt,answer,eos]
-        labels_ids[:speech_length + prompt_length] = -1 #[-1,-1,answer,eos]
-        example_mask = example_ids.ge(-1) #[True,True,True,True]
+        labels_ids[:speech_length + prompt_length] = -1 #[-1,-1,answer,eos];
+        example_mask = example_ids.ge(-1) #FIX(GZF): [True,True,True,True]
+
         label_mask = labels_ids.ge(0) #[False,False,True,True]
         example_ids[~example_mask] = 0 #[speech,prompt,answer,eos]
         labels_ids[~label_mask] = self.IGNORE_INDEX #[-100,answer,eos,-100]
@@ -91,7 +94,8 @@ def __getitem__(self, index):
             "labels": labels_ids,
             "attention_mask": example_mask,
             'speech_mel': speech_mel,
-            'speech_length': speech_length
+            'speech_length': speech_length,
+
         }
 
 
@@ -139,16 +143,16 @@ def collator(self, samples):
                                  for s in samples])
         labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX)
                               for s in samples])
-        attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False)
+        attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False) 
                                       for s in samples])
-
+        
         speech_mel_max_length = max([s['speech_mel'].shape[0] for s in samples])
         speech_mel = torch.stack([self.pad(s['speech_mel'], speech_mel_max_length, 0)
                                   for s in samples])
 
         speech_mask = torch.zeros_like(attention_mask)
         for line, sample in enumerate(samples):
-            speech_mask[line, :sample['speech_length']] = 1
+            speech_mask[line, :sample['speech_length']] = 1 #FIX(GZF): sample['speech_length']+1
 
         return {
             'input_ids': input_ids,
diff --git a/src/llama_recipes/model_checkpointing/__init__.py b/src/llama_recipes/model_checkpointing/__init__.py
index 9474f78c..76781b02 100644
--- a/src/llama_recipes/model_checkpointing/__init__.py
+++ b/src/llama_recipes/model_checkpointing/__init__.py
@@ -8,5 +8,6 @@
     save_optimizer_checkpoint,
     save_model_and_optimizer_sharded,
     load_model_sharded,
-    load_sharded_model_single_gpu
+    load_sharded_model_single_gpu,
+    save_model_checkpoint_peft,
 )
diff --git a/src/llama_recipes/model_checkpointing/checkpoint_handler.py b/src/llama_recipes/model_checkpointing/checkpoint_handler.py
index b097df97..e6367bf2 100644
--- a/src/llama_recipes/model_checkpointing/checkpoint_handler.py
+++ b/src/llama_recipes/model_checkpointing/checkpoint_handler.py
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
+import os
 from pathlib import Path
 from datetime import datetime
 import torch
@@ -160,7 +160,22 @@ def save_model_checkpoint(
 
     print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
 
-
+def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0):
+    save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch))
+    os.makedirs(save_dir, exist_ok=True)
+    model.llm.save_pretrained(save_dir)
+
+    save_full_path = os.path.join(save_dir, "model.pt")
+    cpu_state = model.state_dict()
+    project_dict = {}
+    for key in cpu_state.keys():
+        if "_projector" in key:
+            project_dict[key] = cpu_state[key]
+    torch.save(project_dict, save_full_path)
+
+    print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
+
+
 def load_model_checkpoint(model, rank, cfg):
     """load local checkpoint to rank0 cpu
diff --git a/src/llama_recipes/models/slam_model.py b/src/llama_recipes/models/slam_model.py
index bf72ccc3..37b65026 100644
--- a/src/llama_recipes/models/slam_model.py
+++ b/src/llama_recipes/models/slam_model.py
@@ -1,3 +1,4 @@
+import os
 import types
 import torch
 import soundfile as sf
@@ -14,7 +15,8 @@
 from llama_recipes.utils.config_utils import generate_peft_config
 from llama_recipes.utils.train_utils import print_model_size
 
-
+from peft import PeftModel, PeftConfig
+from torch.nn import CrossEntropyLoss
 
 def setup_model(tokenizer, train_config, model_config, **kwargs):
     return slam_model(tokenizer, train_config, model_config, **kwargs)
@@ -104,6 +106,10 @@ def setup_llm(train_config, model_config, **kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
+
+        if kwargs.get("peft_ckpt", None):
+            print("loading ckpt from: ", kwargs.get("peft_ckpt"))
+            model = PeftModel.from_pretrained(model, kwargs.get("peft_ckpt"))
 
     return model
 
@@ -128,8 +134,13 @@ def __init__(
         self.llm = setup_llm(train_config, model_config, **kwargs)
 
         # projector
-        self.speech_encoder_projector = nn.Linear(self.speech_encoder.ln_post.normalized_shape[0] ,self.llm.config.hidden_size)
-
+        self.speech_encoder_projector = nn.Linear(self.speech_encoder.ln_post.normalized_shape[0], self.llm.config.hidden_size)
+        ckpt_path = kwargs.get("ckpt_path", None)
+        # ckpt_path = kwargs.get("ckpt_path", "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0/model.pt")
+        if ckpt_path is not None:
+            print("loading ckpt from: ", ckpt_path)
+            ckpt_dict = torch.load(ckpt_path, map_location="cpu")
+            self.load_state_dict(ckpt_dict, strict=False)
 
         # tokenizer
         self.tokenizer = tokenizer
@@ -152,18 +163,90 @@ def forward(self,
         speech_encoder_outs = None
         if speech_mel is not None:
             speech_encoder_outs = self.speech_encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1))
-            speech_encoder_outs = self.speech_encoder_projector(speech_encoder_outs)
+            speech_encoder_outs = self.speech_encoder_projector.to(speech_encoder_outs.device)(speech_encoder_outs)
 
         input_ids[input_ids == -1] = 0
         if hasattr(self.llm.model, "embed_tokens"):
             inputs_embeds = self.llm.model.embed_tokens(input_ids)
-        else:
+        elif hasattr(self.llm.model.model, "embed_tokens"):
             inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
+        else:
+            inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
+
         batch_size, token_num, dims = inputs_embeds.shape
         _, l, _ = speech_encoder_outs.shape
         speech_encoder_outs_pad = F.pad(speech_encoder_outs, (0, 0, 0, token_num-l, 0, 0), value=0.0)
         inputs_embeds = speech_encoder_outs_pad * speech_mask[:, :, None] + inputs_embeds * (~speech_mask[:, :, None])
+
+        model_outputs = self.llm.to(speech_encoder_outs.device)(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
+
+        return model_outputs
+
+    @torch.no_grad()
+    def generate(
+        self,
+        wav_path = None,
+        generation_config = None,
+        logits_processor = None,
+        stopping_criteria = None,
+        prefix_allowed_tokens_fn = None,
+        synced_gpus = None,
+        assistant_model = None,
+        streamer = None,
+        negative_prompt_ids = None,
+        negative_prompt_attention_mask = None,
+        **kwargs,
+    ):
+
+        device = kwargs.get("device", "cuda")
+        assert os.path.exists(wav_path)
+        speech_raw = whisper.load_audio(wav_path)
+        # speech_raw = whisper.pad_or_trim(speech_raw)
+        speech_mel = whisper.log_mel_spectrogram(speech_raw).permute(1,0)[None, :, :].to(device)
+
+        speech_encoder_outs = self.speech_encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1))
+        speech_encoder_outs = self.speech_encoder_projector.to(speech_encoder_outs.device)(speech_encoder_outs)
+
+        prompt="""
+        Please provide an emotional response based on the emotional speech you hear.
+        Remember to format your answer as follows: <|EMOTION|><|REPLY|>.
+        <|EMOTION|> is a standalone adjective.
+        <|REPLY|> is a reply based on a the speech.
+        """
+        prompt = "USER: {}\n ASSISTANT:".format(prompt)
+        prompt_ids = self.tokenizer.encode(prompt) # FIX(GZF)
+        prompt_length = len(prompt_ids)
+        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device)
+
+        if hasattr(self.llm.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
+        elif hasattr(self.llm.model.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
+        else:
+            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
+
+        inputs_embeds = torch.cat((speech_encoder_outs, inputs_embeds[None, :, :]), dim=1) # [speech,prompt]
+
+        atts = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(inputs_embeds.device)
+
+        # generate
+        output = self.llm.generate(
+            inputs_embeds=inputs_embeds,
+            max_length=kwargs.get("max_length", 200),
+            num_beams=kwargs.get("num_beams", 1),
+            do_sample=kwargs.get("do_sample", True),
+            min_length=kwargs.get("min_length", 1),
+            top_p=kwargs.get("top_p", 0.9),
+            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+            length_penalty=kwargs.get("length_penalty", 1.0),
+            temperature=kwargs.get("temperature", 1.0),
+            attention_mask=atts,
+            bos_token_id=self.tokenizer.bos_token_id,
+            eos_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=self.tokenizer.pad_token_id
+        )
+
+        output_text = self.tokenizer.batch_decode(output, add_special_tokens=False, skip_special_tokens=True)
 
-        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
+        return output_text
 
-        return model_outputs
\ No newline at end of file
diff --git a/src/llama_recipes/pipeline/inference.py b/src/llama_recipes/pipeline/inference.py
new file mode 100644
index 00000000..385ed425
--- /dev/null
+++ b/src/llama_recipes/pipeline/inference.py
@@ -0,0 +1,42 @@
+import fire
+import random
+import torch
+# import argparse
+from llama_recipes.models.slam_model import slam_model
+# config
+from llama_recipes.configs import fsdp_config as FSDP_CONFIG
+from llama_recipes.configs import train_config as TRAIN_CONFIG
+from llama_recipes.configs import model_config as MODEL_CONFIG
+from llama_recipes.utils.config_utils import (
+    update_config,
+    generate_peft_config,
+    generate_dataset_config,
+    get_dataloader_kwargs,
+)
+from llama_recipes.pipeline.model_factory import model_factory
+
+def main(**kwargs):
+
+    # Update the configuration for the training and sharding process
+    train_config, fsdp_config, model_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG()
+    update_config((train_config, fsdp_config, model_config), **kwargs)
+
+    # Set the seeds for reproducibility
+    torch.cuda.manual_seed(train_config.seed)
+    torch.manual_seed(train_config.seed)
+    random.seed(train_config.seed)
+
+    model, tokenizer = model_factory(train_config, model_config, **kwargs)
+    model.to(kwargs.get("device", "cuda"))
+    model.eval()
+
+    print("=====================================")
+    # wav_path = input("Your Wav Path:\n")
+    # prompt = input("Your Prompt:\n")
+    wav_path = "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_M001.wav"
+    print(model.generate(wav_path))
+
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
\ No newline at end of file
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index f154f965..d1048af1 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -18,7 +18,7 @@
 
 from transformers import LlamaTokenizer
 
-from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
+from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_model_checkpoint_peft
 from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 
@@ -144,7 +144,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         print(f"we are about to save the PEFT modules")
                 else:
                     print(f"we are about to save the PEFT modules")
-                model.save_pretrained(train_config.output_dir)
+                # model.save_pretrained(train_config.output_dir)
+                save_model_checkpoint_peft(
+                    model, optimizer, rank, train_config, epoch=epoch
+                )
                 if train_config.enable_fsdp:
                     if rank==0:
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")