Merge pull request #4 from ddlBoJack/main
update with main
LauraGPT authored Dec 4, 2023
2 parents 1da9c02 + 3d6216e commit bdd527d
Showing 13 changed files with 234 additions and 87 deletions.
59 changes: 44 additions & 15 deletions scripts/finetune_echat.sh
@@ -2,7 +2,8 @@
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0,1
export CUDA_LAUNCH_BLOCKING=1
export OMP_NUM_THREADS=1
# export OMP_NUM_THREADS=1
# export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128

# debug setting for multiple gpus
# export NCCL_DEBUG=INFO
@@ -11,57 +12,85 @@ export OMP_NUM_THREADS=1

cd /root/SLAM-LLM

speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/base.pt
# speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt
speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt
llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/maziyang.mzy/models/llama-2-hf-finetune
output_dir=/nfs/maziyang.mzy/models/llama-2-hf-proj2048-debug

# -m debugpy --listen 5678 --wait-for-client
if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
python src/llama_recipes/pipeline/finetune.py \
python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \
--model_name echat \
--use_peft --peft_method lora \
--freeze_encoder \
--freeze_llm \
--use_fp16 \
--llm_name llama-2-7b-hf \
--llm_path $llm_path \
--encoder_name whisper \
--encoder_ds_rate 2 \
--encoder_path $speech_encoder_path \
--encoder_projector linear \
--encoder_projector_ds_rate 5 \
--dataset custom_dataset \
--custom_dataset.file src/llama_recipes/datasets/echat_dataset.py:get_audio_dataset \
--custom_dataset.data_path /nfs/zhifu.gzf/data/IEMOCAP_full_release/datalist.jsonl \
--batching_strategy custom \
--custom_dataset.max_words 1024 \
--num_epochs 100 \
--batch_size_training 2 \
--val_batch_size 2 \
--output_dir $output_dir \
--run_test_during_validation true \
--run_test_during_validation_file /nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_M013.wav \
# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \
--run_test_during_validation \
--run_test_during_validation_file /nfs/zhifu.gzf/data/IEMOCAP_full_release/Session5/sentences/wav/Ses05M_impro04/Ses05M_impro04_M040.wav \
# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1"
# --use_peft --peft_method lora \

# train
# {"trans": "Well, do you have your passport?\n",
# "emotion": "xxx",
# "wav": "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_F009.wav"}
# {"trans": "No, I don't have a passport.\n",
# "emotion": "neu",
# "wav": "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_M010.wav"}

# val
# {"trans": "Yeah, well thanks for your help.\n",
# "emotion": "ang",
# "wav": "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session5/sentences/wav/Ses05M_impro04/Ses05M_impro04_M040.wav"}
# {"trans": "I'm sorry. Good luck, man.\n",
# "emotion": "xxx",
# "wav": "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session5/sentences/wav/Ses05M_impro04/Ses05M_impro04_F038.wav"}

else
torchrun \
--nnodes 1 \
--nproc_per_node 2 \
src/llama_recipes/pipeline/finetune.py \
--model_name echat \
--freeze_encoder \
--use_fp16 \
--enable_fsdp \
--use_peft --peft_method lora \
--llm_name llama-2-7b-hf \
--llm_path $llm_path \
--encoder_name whisper \
--encoder_ds_rate 2 \
--encoder_path $speech_encoder_path \
--encoder_projector linear \
--encoder_projector_ds_rate 5 \
--dataset custom_dataset \
--custom_dataset.file src/llama_recipes/datasets/echat_dataset.py:get_audio_dataset \
--custom_dataset.data_path /nfs/zhifu.gzf/data/IEMOCAP_full_release/datalist.jsonl \
--batching_strategy custom \
--custom_dataset.max_words 1024 \
--num_epochs 100 \
--batch_size_training 8 \
--val_batch_size 8 \
--batch_size_training 2 \
--val_batch_size 2 \
--output_dir $output_dir \
--run_test_during_validation \
--run_test_during_validation_file /nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_M013.wav \
# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \
fi
--run_test_during_validation_file /nfs/zhifu.gzf/data/IEMOCAP_full_release/Session5/sentences/wav/Ses05M_impro04/Ses05M_impro04_M040.wav \
# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1"
# --freeze_llm \
fi
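
Note on the change above: the single-GPU branch now launches finetune.py under debugpy and blocks at startup until a debugger attaches to port 5678. For reference, the two flags are equivalent to calling the debugpy API inside the script (a sketch only; the port number is the one used above):

import debugpy

debugpy.listen(5678)       # same port as --listen 5678 in the command above
debugpy.wait_for_client()  # block until a debugger attaches, like --wait-for-client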
14 changes: 9 additions & 5 deletions scripts/inference_echat.sh
@@ -1,19 +1,21 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=1
export CUDA_LAUNCH_BLOCKING=1

cd /root/SLAM-LLM

speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/base.pt
# speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/base.pt
speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt
llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/maziyang.mzy/models/llama-2-hf-finetune

# -m debugpy --listen 5678 --wait-for-client
#python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \
python src/llama_recipes/pipeline/inference.py \
--model_name echat \
--use_peft --peft_method lora \
--freeze_llm \
--use_fp16 \
--quantization \
--llm_name llama-2-7b-hf \
--llm_path $llm_path \
@@ -28,5 +30,7 @@ python src/llama_recipes/pipeline/inference.py \
--num_epochs 1 \
--batch_size_training 2 \
--output_dir $output_dir \
--ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/0/model.pt" \
--peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/0"
--ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1/model.pt" \
--wav_path "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session5/sentences/wav/Ses05M_impro04/Ses05M_impro04_F035.wav"
# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1"
# --use_peft --peft_method lora \
3 changes: 2 additions & 1 deletion src/llama_recipes/configs/datasets.py
@@ -33,4 +33,5 @@ class custom_dataset:
    train_split: str = "train"
    test_split: str = "validation"
    data_path: str = NotImplemented
    max_words: int = NotImplemented
    max_words: int = NotImplemented
    max_mel: int = 1000
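
The new max_mel field caps the number of log-mel frames kept per utterance (see the [:self.max_mel] slice added in echat_dataset.py below). Assuming Whisper's standard front end (16 kHz audio, 160-sample hop, i.e. 100 mel frames per second), the default of 1000 frames keeps roughly 10 seconds of audio:

# Assumed Whisper front end: 16 kHz sample rate, 160-sample hop -> 100 mel frames/s.
max_mel = 1000
max_seconds = max_mel * 160 / 16000  # = 10.0 s retained after truncation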
4 changes: 3 additions & 1 deletion src/llama_recipes/configs/model.py
@@ -6,5 +6,7 @@ class model_config:
    llm_name: str = "llama-2-7b-hf"
    llm_path: str = "PATH/to/LLAMA/7B"
    encoder_name: str = None
    encoder_ds_rate: int = 2
    encoder_path: str = None
    encoder_projector: str = "linear"
    encoder_projector: str = "linear"
    encoder_projector_ds_rate: int = 5
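
The new fields only declare downsampling rates (encoder_ds_rate for the Whisper encoder, encoder_projector_ds_rate for the projector); they say nothing about the projector's implementation. Purely as an illustration of what a "linear" projector with a 5x rate could look like, not necessarily the repository's actual module, and with the 1280-dim Whisper large-v2 feature width and 4096-dim Llama-2-7B hidden size as assumptions:

import torch.nn as nn

class LinearProjectorSketch(nn.Module):
    """Illustrative only: stack every ds_rate encoder frames, then project to the LLM width."""
    def __init__(self, enc_dim=1280, llm_dim=4096, ds_rate=5):
        super().__init__()
        self.ds_rate = ds_rate
        self.proj = nn.Linear(enc_dim * ds_rate, llm_dim)

    def forward(self, x):                      # x: (batch, frames, enc_dim)
        b, t, d = x.shape
        t = t - t % self.ds_rate               # drop the ragged tail
        x = x[:, :t, :].reshape(b, t // self.ds_rate, d * self.ds_rate)
        return self.proj(x)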
2 changes: 2 additions & 0 deletions src/llama_recipes/configs/training.py
@@ -38,3 +38,5 @@ class train_config:
    use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformers memory-efficient kernels
    run_test_during_validation: bool = False
    run_test_during_validation_file: str = "test.wav"
    freeze_llm: bool = False
    freeze_encoder: bool = False
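
freeze_llm and freeze_encoder are new switches consumed elsewhere in this commit (see the checkpoint handler below, which skips saving frozen parts). A minimal sketch of how a trainer might apply them, assuming the model exposes .llm and .encoder submodules as the checkpoint code suggests:

# Hedged sketch: disable gradients for whichever parts are frozen.
if train_config.freeze_llm:
    for p in model.llm.parameters():
        p.requires_grad = False
if train_config.freeze_encoder:
    for p in model.encoder.parameters():
        p.requires_grad = False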
46 changes: 32 additions & 14 deletions src/llama_recipes/datasets/echat_dataset.py
@@ -11,6 +11,7 @@
import torchaudio
from torch.utils.data import Dataset
import whisper
from llama_recipes.utils.compute_utils import calculate_output_length_1d


class EChatDataset(Dataset):
@@ -23,36 +24,52 @@ def __init__(
        super().__init__()

        self.dataset_config = dataset_config
        self.max_words = dataset_config.max_words
        self.max_mel = dataset_config.max_mel
        self.tokenizer = tokenizer
        self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss
        self.prompt_template = "USER: {}\n ASSISTANT:"
        self.answer_template = "<|{}|><|{}|>"

        with open(dataset_config.data_path, 'r') as file:
            data = file.readlines()

        sentence_list = []

        for item in data:
            dialog_name, dialog = item.split('\t', 1)
            dialog_list = eval(dialog)
            for sentence_id in range(len(dialog_list)-2):
                if 'emotion' in dialog_list[sentence_id].keys() and 'emotion' in dialog_list[sentence_id+1].keys():
                    if dialog_list[sentence_id+1]['emotion'] != 'xxx':
                        sentence_dict = {}
                        sentence_dict['pre_wav'] = dialog_list[sentence_id]['wav']
                        sentence_dict['post_emotion'] = dialog_list[sentence_id+1]['emotion']
                        sentence_dict['post_trans'] = dialog_list[sentence_id+1]['trans']
                        sentence_list.append(sentence_dict)

        total_sentence = len(sentence_list)
        print(f"Using {total_sentence} sentence totally.")
        # if split == "train":
        #     self.data = sentence_list[:int(total_sentence * 0.9)]
        # else:
        #     self.data = sentence_list[int(total_sentence * 0.9):]

        # debug
        if split == "train":
            self.data = data[:60]
            self.data = sentence_list[:8]
        else:
            self.data = data[60:]
            self.data = sentence_list[8:16]


    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        dialog_name, dialog = item.split('\t', 1)
        dialog_list = eval(dialog)

        while True:
            sentence_id = random.randint(0, len(dialog_list)-2)
            if 'emotion' in dialog_list[sentence_id].keys() and 'emotion' in dialog_list[sentence_id+1].keys():
                if dialog_list[sentence_id]['emotion'] != 'xxx' and dialog_list[sentence_id+1]['emotion'] != 'xxx':
                    break
        speech_raw = whisper.load_audio(dialog_list[sentence_id]['wav'])

        speech_raw = whisper.load_audio(item['pre_wav'])
        # speech_raw = whisper.pad_or_trim(speech_raw)
        speech_mel = whisper.log_mel_spectrogram(speech_raw).permute(1,0)
        speech_mel = whisper.log_mel_spectrogram(speech_raw).permute(1,0)[:self.max_mel]

        prompt="""
Please provide an emotional response based on the emotional speech you hear.
@@ -65,12 +82,13 @@ def __getitem__(self, index):
        """

        prompt = self.prompt_template.format(prompt)
        answer = self.answer_template.format(dialog_list[sentence_id+1]['emotion'], dialog_list[sentence_id+1]['trans'])
        answer = self.answer_template.format(item['post_emotion'], item['post_trans'])

        prompt_ids = self.tokenizer.encode(prompt)

        prompt_length = len(prompt_ids)
        speech_length = (speech_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats
        speech_length = calculate_output_length_1d(speech_length, 5, 5) # ad-hoc for 5x conv1d downsample
        speech_pseudo = torch.full((speech_length,),-1)

        example = prompt + answer #FIX(MZY): avoid putting a bos token before answer.
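
About the new speech_length lines: after Whisper's 2x downsample from mel frames to encoder features, the length is passed through calculate_output_length_1d(speech_length, 5, 5), imported above from llama_recipes.utils.compute_utils. A minimal sketch, assuming that helper follows the standard Conv1d output-length formula (its exact signature and defaults are assumptions here):

import math

def conv1d_output_length(length_in, kernel_size, stride, padding=0, dilation=1):
    # Standard Conv1d length formula; assumed to match calculate_output_length_1d.
    return math.floor((length_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)

mel_frames = 3000                      # 30 s of audio at Whisper's 10 ms hop
feats = (mel_frames + 1) // 2          # 2x downsample from mel frames to encoder features -> 1500
tokens = conv1d_output_length(feats, kernel_size=5, stride=5)  # 5x conv1d downsample -> 300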
3 changes: 2 additions & 1 deletion src/llama_recipes/datasets/speech_dataset.py
@@ -39,7 +39,7 @@ def __init__(self,
        if split == "train":
            self.data_list = contents[:-1000]
        else:
            self.data_list = contents[1000:]
            self.data_list = contents[-1000:]

    def get_source_len(self, data_dict):
        return data_dict["source_len"]
@@ -76,6 +76,7 @@ def __getitem__(self, index):

        prompt_length = len(prompt_ids)
        speech_length = (speech_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats
        speech_length = calculate_output_length_1d(speech_length, 5, 5) # ad-hoc for 5x conv1d downsample
        speech_pseudo = torch.full((speech_length,), -1)

        example = prompt + answer # FIX(MZY): avoid putting a bos token before answer.
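
The one-line change above fixes the held-out split: training already takes everything except the last 1000 entries, so the matching validation set is the last 1000 entries (contents[-1000:]), not contents[1000:], which overlapped with the training data. A tiny check of the slicing, with illustrative values only:

contents = list(range(5000))
train, val = contents[:-1000], contents[-1000:]
assert len(val) == 1000 and set(train).isdisjoint(val)
# The old contents[1000:] would have shared items 1000..3999 with the training split.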
9 changes: 7 additions & 2 deletions src/llama_recipes/model_checkpointing/checkpoint_handler.py
@@ -164,13 +164,18 @@ def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0):
    print(f"--> saving model ...")
    save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch))
    os.makedirs(save_dir, exist_ok=True)
    model.llm.save_pretrained(save_dir)
    if not cfg.freeze_llm:
        model.llm.save_pretrained(save_dir)

    save_full_path = os.path.join(save_dir, "model.pt")
    cpu_state = model.state_dict()
    project_dict = {}
    if not cfg.freeze_encoder:
        for key in cpu_state.keys():
            if key.startswith("encoder."):
                project_dict[key] = cpu_state[key]
    for key in cpu_state.keys():
        if "_projector" in key:
        if key.startswith("encoder_projector."):
            project_dict[key] = cpu_state[key]
    torch.save(project_dict, save_full_path)

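
With this change, model.pt now holds only encoder_projector tensors, plus the encoder tensors when the encoder is trainable (not frozen); the LoRA-adapted LLM is written separately via save_pretrained when the LLM is not frozen. A hypothetical note on reloading such a partial checkpoint; the path and the surrounding model object are placeholders:

import torch

ckpt = torch.load("/path/to/output_dir/echat/1/model.pt", map_location="cpu")
missing, unexpected = model.load_state_dict(ckpt, strict=False)  # non-strict: only encoder/projector keys are present
print(f"restored {len(ckpt)} tensors; {len(missing)} parameters left at their current values")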