From 72e20591a1c59373a13a67e792ec1f2536b71294 Mon Sep 17 00:00:00 2001 From: ddlBoJack Date: Mon, 15 Jan 2024 13:25:10 +0800 Subject: [PATCH 1/3] update compute ppl --- scripts/compute_wer.sh | 4 +- scripts/finetune_asr_vicuna.sh | 14 +++---- scripts/inference_asr_batch.sh | 16 +++---- .../datasets/speech_dataset_inference.py | 4 +- src/llama_recipes/utils/compute_ppl.py | 42 +++++++++++++++++++ 5 files changed, 61 insertions(+), 19 deletions(-) create mode 100644 src/llama_recipes/utils/compute_ppl.py diff --git a/scripts/compute_wer.sh b/scripts/compute_wer.sh index 9ff89f71..c6035eda 100644 --- a/scripts/compute_wer.sh +++ b/scripts/compute_wer.sh @@ -1,7 +1,7 @@ #cd /root/SLAM-LLM -trans="/nfs/maziyang.mzy/exps/llama-2-chat-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-padding30-20240111/asr/3/decode_log_test_other_beam4_repetition_penalty1_gt" -preds="/nfs/maziyang.mzy/exps/llama-2-chat-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-padding30-20240111/asr/3/decode_log_test_other_beam4_repetition_penalty1_pred" +trans="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-padding30-20240113/asr/2/decode_log_test_other_beam4_repetition_penalty1_gt" +preds="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-padding30-20240113/asr/2/decode_log_test_other_beam4_repetition_penalty1_pred" # python src/llama_recipes/utils/preprocess_text.py ${preds} ${preds}.proc # python src/llama_recipes/utils/compute_wer.py ${trans} ${preds}.proc ${preds}.proc.wer diff --git a/scripts/finetune_asr_vicuna.sh b/scripts/finetune_asr_vicuna.sh index 15d4886e..6e3b52b2 100644 --- a/scripts/finetune_asr_vicuna.sh +++ b/scripts/finetune_asr_vicuna.sh @@ -1,7 +1,7 @@ #!/bin/bash # export PYTHONPATH=/root/whisper:$PYTHONPATH export PYTHONPATH=/root/fairseq:$PYTHONPATH -export CUDA_VISIBLE_DEVICES=4,5,6,7 +export CUDA_VISIBLE_DEVICES=2,3,4,5 # export CUDA_LAUNCH_BLOCKING=1 export OMP_NUM_THREADS=1 @@ -12,13 +12,13 @@ export OMP_NUM_THREADS=1 cd /root/SLAM-LLM -speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt -# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt +# speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt +speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 # llm_path=/nfs/maziyang.mzy/models/vicuna-13b-v1.5 -output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-paddingr-20240112 +output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-qwen-prompt-padding30-20240113 # -m debugpy --listen 5678 --wait-for-client if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then @@ -61,6 +61,7 @@ else torchrun \ --nnodes 1 \ --nproc_per_node 4 \ +--master_port=29502 \ src/llama_recipes/pipeline/finetune.py \ --model_name asr \ --freeze_encoder \ @@ -81,8 +82,8 @@ src/llama_recipes/pipeline/finetune.py \ --speech_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \ --batching_strategy custom \ --num_epochs 100 \ ---batch_size_training 6 \ ---val_batch_size 6 \ +--batch_size_training 4 \ +--val_batch_size 4 \ --num_workers_dataloader 4 \ --lr 1e-4 \ --output_dir $output_dir \ @@ -97,7 +98,6 @@ src/llama_recipes/pipeline/finetune.py \ # --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4" \ # --ckpt_path 
"/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4/model.pt" \ # --use_peft --peft_method lora \ -# --master_port=29501 \ fi # {"key": "1001-134707-0000_ASR", "prompt": "", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav", "target": "1 little recks the laborer. How near his work is holding him to God, The loving laborer through space and time, after all, not to create, only or found only.", "target_len": 157, "source_len": 1581, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": ""} diff --git a/scripts/inference_asr_batch.sh b/scripts/inference_asr_batch.sh index 385784e0..2a0efc8f 100644 --- a/scripts/inference_asr_batch.sh +++ b/scripts/inference_asr_batch.sh @@ -1,19 +1,19 @@ #!/bin/bash #export PYTHONPATH=/root/whisper:$PYTHONPATH -export CUDA_VISIBLE_DEVICES=1 +export CUDA_VISIBLE_DEVICES=7 # export CUDA_LAUNCH_BLOCKING=1 cd /root/SLAM-LLM -speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt -# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt +# speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt +speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt # llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf -llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf -# llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 +# llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 -output_dir=/nfs/maziyang.mzy/exps/llama-2-chat-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-padding30-20240111 -ckpt_path=$output_dir/asr/3 +output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-qwen-prompt-padding30-20240113 +ckpt_path=$output_dir/asr/2 # peft_ckpt=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102/asr/4 val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_test_other_filtered.jsonl decode_log=$ckpt_path/decode_log_test_other_beam4_repetition_penalty1 @@ -22,7 +22,7 @@ decode_log=$ckpt_path/decode_log_test_other_beam4_repetition_penalty1 python src/llama_recipes/pipeline/inference_batch.py \ --model_name asr \ --freeze_encoder \ ---llm_name llama-2-7b-chat-hf \ +--llm_name vicuna-7b-v1.5 \ --llm_path $llm_path \ --llm_dim 4096 \ --encoder_name whisper \ diff --git a/src/llama_recipes/datasets/speech_dataset_inference.py b/src/llama_recipes/datasets/speech_dataset_inference.py index 8ead2b23..2af2db72 100644 --- a/src/llama_recipes/datasets/speech_dataset_inference.py +++ b/src/llama_recipes/datasets/speech_dataset_inference.py @@ -29,7 +29,7 @@ def __init__(self, # self.data_list = contents self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss - self.prompt_template = "USER: {}\n ASSISTANT:" + self.prompt_template = "{}" self.fix_length_audio = dataset_config.fix_length_audio self.data_list = [] @@ -72,7 +72,7 @@ def __getitem__(self, index): # audio_raw = np.concatenate((np.zeros(random.randint(0, 16000)), audio_raw, np.zeros(random.randint(0, 16000)))).astype(audio_raw.dtype)[:16000*30] audio_mel = whisper.log_mel_spectrogram(audio_raw).permute(1, 0) - prompt = "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. 
" + prompt = "" prompt = self.prompt_template.format(prompt) prompt_ids = self.tokenizer.encode(prompt) diff --git a/src/llama_recipes/utils/compute_ppl.py b/src/llama_recipes/utils/compute_ppl.py new file mode 100644 index 00000000..55baa292 --- /dev/null +++ b/src/llama_recipes/utils/compute_ppl.py @@ -0,0 +1,42 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer +import torch +from tqdm import tqdm +import json + +MODEL_PATH = "/nfs/maziyang.mzy/models/vicuna-7b-v1.5" +tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) +model = AutoModelForCausalLM.from_pretrained(MODEL_PATH) + +device = 'cuda:6' +model.to(device) +model.eval() + +corpus_path = "/nfs/maziyang.mzy/data/librispeech/librispeech_test_other_filtered.jsonl" +corpus = [] +with open(corpus_path, encoding='utf-8') as fin: + for line in fin: + data_dict = json.loads(line.strip()) + corpus.append(data_dict.get("target", None)) + +cumulative_log_likelihood = 0 +total_tokens = 0 + +for sentence in tqdm(corpus): + inputs = tokenizer(sentence, return_tensors="pt").to(device) + + input_ids = inputs["input_ids"] + input_len = input_ids.size(1) + total_tokens += input_len + + with torch.no_grad(): + outputs = model(**inputs, labels=input_ids) + log_likelihood = outputs.loss * input_len + cumulative_log_likelihood += log_likelihood.item() + + +average_log_likelihood = cumulative_log_likelihood / total_tokens +corpus_ppl = torch.exp(torch.tensor(average_log_likelihood)).item() + +print(f"Model: {MODEL_PATH}") +print(f"Corpus: {corpus_path}") +print(f"Corpus Perplexity: {corpus_ppl}") From 248f86e53bd128e1982077e5071f59a2b1d1e4b4 Mon Sep 17 00:00:00 2001 From: ddlBoJack Date: Mon, 15 Jan 2024 13:45:42 +0800 Subject: [PATCH 2/3] fix keys named problem --- src/llama_recipes/datasets/audio_dataset.py | 6 +-- src/llama_recipes/datasets/avsr_dataset.py | 38 +++---------------- src/llama_recipes/datasets/speech_dataset.py | 6 +-- .../datasets/speech_dataset_inference.py | 6 +-- src/llama_recipes/models/slam_model.py | 27 ++++++------- 5 files changed, 26 insertions(+), 57 deletions(-) diff --git a/src/llama_recipes/datasets/audio_dataset.py b/src/llama_recipes/datasets/audio_dataset.py index e441b0fe..b4b0cf69 100644 --- a/src/llama_recipes/datasets/audio_dataset.py +++ b/src/llama_recipes/datasets/audio_dataset.py @@ -147,9 +147,9 @@ def collator(self, samples): audio_mel_mask = torch.zeros(len(samples), audio_mel_max_length) for line, sample in enumerate(samples): audio_mel_mask[line, :sample['audio_mel'].shape[0]] = 1 - audio_mask = torch.zeros_like(attention_mask) + modality_mask = torch.zeros_like(attention_mask) for line, sample in enumerate(samples): - audio_mask[line, :sample['audio_length']] = 1 + modality_mask[line, :sample['audio_length']] = 1 return { 'input_ids': input_ids, @@ -157,7 +157,7 @@ def collator(self, samples): 'attention_mask': attention_mask, 'audio_mel': audio_mel, 'audio_mel_mask': audio_mel_mask, - 'audio_mask': audio_mask + 'modality_mask': modality_mask } diff --git a/src/llama_recipes/datasets/avsr_dataset.py b/src/llama_recipes/datasets/avsr_dataset.py index c16b5a77..a00f44cb 100644 --- a/src/llama_recipes/datasets/avsr_dataset.py +++ b/src/llama_recipes/datasets/avsr_dataset.py @@ -252,9 +252,9 @@ def collator(self, samples): attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False) for s in samples]) - audio_mask = torch.zeros_like(attention_mask) + modality_mask = torch.zeros_like(attention_mask) for line, sample in enumerate(samples): - audio_mask[line, 
:sample['audio_length']] = 1 #downsample 再/5 + modality_mask[line, :sample['audio_length']] = 1 #downsample 再/5 # audio & mask if not self.modal == "VO": @@ -279,13 +279,6 @@ def collator(self, samples): vis_len = None inputBatch = (aud_seq_list, aud_padding_mask, vis_seq_list, vis_len) #!!! - - # targetinBatch = pad_sequence([data[1] for data in dataBatch], batch_first=True) - # targetoutBatch = pad_sequence([data[2] for data in dataBatch], batch_first=True) - # targetLenBatch = torch.stack([data[3] for data in dataBatch]) - targetinBatch = pad_sequence([data['trgtin'] for data in samples], batch_first=True) - targetoutBatch = pad_sequence([data['trgtout'] for data in samples], batch_first=True) - targetLenBatch = torch.stack([data['trgtLen'] for data in samples]) if self.modal == "AO": inputBatch = (inputBatch[0].float(), inputBatch[1], None, None) @@ -293,40 +286,19 @@ def collator(self, samples): inputBatch = (None, None, inputBatch[2].float(), inputBatch[3].int()) else: inputBatch = (inputBatch[0].float(), inputBatch[1], inputBatch[2].float(), inputBatch[3].int()) - - targetinBatch = targetinBatch.int() - targetoutBatch = targetoutBatch.int() - targetLenBatch = targetLenBatch.int() - targetMask = torch.zeros_like(targetoutBatch, device=targetoutBatch.device) - targetMask[(torch.arange(targetMask.shape[0]), targetLenBatch.long() - 1)] = 1 - targetMask = (1 - targetMask.flip([-1]).cumsum(-1).flip([-1])).bool() - - # return { - # "inputBatch0": inputBatch[0], - # "inputBatch1": inputBatch[1], - # "inputBatch2": inputBatch[2], - # "inputBatch3": inputBatch[3], - - # "targetoutBatch": targetoutBatch, - # "targetLenBatch": targetLenBatch.long(), - # 'maskw2v': True, - # } + return { 'input_ids': input_ids, #torch.Size([4, 114]) 'labels': labels, #torch.Size([4, 114]) 'attention_mask': attention_mask, #torch.Size([4, 114]) # 'audio_mel': audio_mel, # 'audio_mel_post_mask': audio_mel_post_mask, - 'audio_mask': audio_mask, + 'modality_mask': modality_mask, "audio": inputBatch[0], #torch.Size([4, 92800]) - "audiomask": inputBatch[1], #torch.Size([4, 92800]) + "audio_mask": inputBatch[1], #torch.Size([4, 92800]) "visual": inputBatch[2], #torch.Size([4, 146, 1, 112, 112]) "vis_len": inputBatch[3], #torch.Size([4]) - - "targetoutBatch": targetoutBatch, #torch.Size([4, 50]) - "targetLenBatch": targetLenBatch.long(), #torch.Size([4]) - 'maskw2v': True, } def pad(self, sequence, max_length, padding_idx=0): diff --git a/src/llama_recipes/datasets/speech_dataset.py b/src/llama_recipes/datasets/speech_dataset.py index 18a5dd05..5a51d330 100644 --- a/src/llama_recipes/datasets/speech_dataset.py +++ b/src/llama_recipes/datasets/speech_dataset.py @@ -149,9 +149,9 @@ def collator(self, samples): for line, sample in enumerate(samples): audio_mel_post_mask[line, :(sample['audio_mel'].shape[0] + 1) // 2] = 1 - audio_mask = torch.zeros_like(attention_mask) + modality_mask = torch.zeros_like(attention_mask) for line, sample in enumerate(samples): - audio_mask[line, :sample['audio_length']] = 1 + modality_mask[line, :sample['audio_length']] = 1 return { 'input_ids': input_ids, @@ -159,7 +159,7 @@ def collator(self, samples): 'attention_mask': attention_mask, 'audio_mel': audio_mel, 'audio_mel_post_mask': audio_mel_post_mask, - 'audio_mask': audio_mask + 'modality_mask': modality_mask } diff --git a/src/llama_recipes/datasets/speech_dataset_inference.py b/src/llama_recipes/datasets/speech_dataset_inference.py index 2af2db72..98675c98 100644 --- a/src/llama_recipes/datasets/speech_dataset_inference.py +++ 
b/src/llama_recipes/datasets/speech_dataset_inference.py @@ -127,9 +127,9 @@ def collator(self, samples): for line, sample in enumerate(samples): audio_mel_post_mask[line, :(sample['audio_mel'].shape[0] + 1) // 2] = 1 - audio_mask = torch.zeros_like(attention_mask) + modality_mask = torch.zeros_like(attention_mask) for line, sample in enumerate(samples): - audio_mask[line, :sample['audio_length']] = 1 + modality_mask[line, :sample['audio_length']] = 1 keys = [s['key'] for s in samples] targets = [s['target'] for s in samples] @@ -138,7 +138,7 @@ def collator(self, samples): 'attention_mask': attention_mask, 'audio_mel': audio_mel, 'audio_mel_post_mask': audio_mel_post_mask, - 'audio_mask': audio_mask, + 'modality_mask': modality_mask, 'keys': keys, 'targets': targets } diff --git a/src/llama_recipes/models/slam_model.py b/src/llama_recipes/models/slam_model.py index 98a04689..bf661fb5 100644 --- a/src/llama_recipes/models/slam_model.py +++ b/src/llama_recipes/models/slam_model.py @@ -185,16 +185,13 @@ def forward(self, audio_mel = kwargs.get("audio_mel", None) audio_mel_mask = kwargs.get("audio_mel_mask", None) audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None) # 2x downsample for whisper - audio_mask = kwargs.get("audio_mask", None) - - audio = kwargs.get("audio", None) #torch.Size([2, 96480]) - audiomask = kwargs.get("audiomask", None) #删 #torch.Size([2, 96480]) - visual = kwargs.get("visual", None) #torch.Size([2, 151, 1, 112, 112]) - vis_len = kwargs.get("vis_len", None) #tensor([ 77, 151], device='cuda:0', dtype=torch.int32) - maskw2v = kwargs.get("maskw2v", None) #True - targetoutBatch = kwargs.get("targetoutBatch", None) #torch.Size([2, 29]) - targetLenBatch = kwargs.get("targetLenBatch", None) #tensor([18, 29], device='cuda:0') + modality_mask = kwargs.get("modality_mask", None) + audio = kwargs.get("audio", None) + audio_mask = kwargs.get("audio_mask", None) + visual = kwargs.get("visual", None) + vis_len = kwargs.get("vis_len", None) + maskw2v = kwargs.get("maskw2v", False) #(FIX:MZY) False for supervised learning and inference encoder_outs = None @@ -204,7 +201,7 @@ def forward(self, if self.model_config.encoder_name == "beats": encoder_outs, audio_mel_post_mask = self.encoder.extract_features(audio_mel, audio_mel_mask) # bs*seq*dim if self.model_config.encoder_name == "moco_wav2vec2": - encoder_outs , inputLenBatch, audio_mel_post_mask = self.encoder((audio, audiomask, visual, vis_len) ,maskw2v) # bs*seq*dim + encoder_outs , inputLenBatch, audio_mel_post_mask = self.encoder((audio, audio_mask, visual, vis_len) ,maskw2v) # bs*seq*dim if self.model_config.encoder_projector == "q-former": encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask) @@ -220,11 +217,11 @@ def forward(self, else: inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids) - if audio_mask is not None: + if modality_mask is not None: batch_size, token_num, dims = inputs_embeds.shape _, l, _ = encoder_outs.shape encoder_outs_pad = F.pad(encoder_outs, (0, 0, 0, token_num-l, 0, 0), value=0.0) - inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (~audio_mask[:, :, None]) + inputs_embeds = encoder_outs_pad * modality_mask[:, :, None] + inputs_embeds * (~modality_mask[:, :, None]) model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels) @@ -253,7 +250,7 @@ def generate(self, audio_mel = kwargs.get("audio_mel", None) audio_mel_mask = kwargs.get("audio_mel_mask", None) audio_mel_post_mask = 
kwargs.get("audio_mel_post_mask", None) # 2x downsample for whisper - audio_mask = kwargs.get("audio_mask", None) + modality_mask = kwargs.get("modality_mask", None) encoder_outs = None if audio_mel is not None: @@ -276,11 +273,11 @@ def generate(self, else: inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids) - if audio_mask is not None: + if modality_mask is not None: batch_size, token_num, dims = inputs_embeds.shape _, l, _ = encoder_outs.shape encoder_outs_pad = F.pad(encoder_outs, (0, 0, 0, token_num-l, 0, 0), value=0.0) - inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (~audio_mask[:, :, None]) + inputs_embeds = encoder_outs_pad * modality_mask[:, :, None] + inputs_embeds * (~modality_mask[:, :, None]) model_outputs = self.llm.generate( inputs_embeds=inputs_embeds, From 899352c37663eef42278357a6da971b4a4880d4c Mon Sep 17 00:00:00 2001 From: ddlBoJack Date: Mon, 15 Jan 2024 13:51:36 +0800 Subject: [PATCH 3/3] update --- src/llama_recipes/datasets/speech_dataset_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama_recipes/datasets/speech_dataset_inference.py b/src/llama_recipes/datasets/speech_dataset_inference.py index 98675c98..1a1155dc 100644 --- a/src/llama_recipes/datasets/speech_dataset_inference.py +++ b/src/llama_recipes/datasets/speech_dataset_inference.py @@ -29,7 +29,7 @@ def __init__(self, # self.data_list = contents self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss - self.prompt_template = "{}" + self.prompt_template = "USER: {}\n ASSISTANT:" self.fix_length_audio = dataset_config.fix_length_audio self.data_list = [] @@ -72,7 +72,7 @@ def __getitem__(self, index): # audio_raw = np.concatenate((np.zeros(random.randint(0, 16000)), audio_raw, np.zeros(random.randint(0, 16000)))).astype(audio_raw.dtype)[:16000*30] audio_mel = whisper.log_mel_spectrogram(audio_raw).permute(1, 0) - prompt = "" + prompt = "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " prompt = self.prompt_template.format(prompt) prompt_ids = self.tokenizer.encode(prompt)