update compute_ppl, fix key naming problem #28

Merged 5 commits on Jan 15, 2024
4 changes: 2 additions & 2 deletions scripts/compute_wer.sh
@@ -1,7 +1,7 @@
 #cd /root/SLAM-LLM
 
-trans="/nfs/maziyang.mzy/exps/llama-2-chat-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-padding30-20240111/asr/3/decode_log_test_other_beam4_repetition_penalty1_gt"
-preds="/nfs/maziyang.mzy/exps/llama-2-chat-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-padding30-20240111/asr/3/decode_log_test_other_beam4_repetition_penalty1_pred"
+trans="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-padding30-20240113/asr/2/decode_log_test_other_beam4_repetition_penalty1_gt"
+preds="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-padding30-20240113/asr/2/decode_log_test_other_beam4_repetition_penalty1_pred"
 
 # python src/llama_recipes/utils/preprocess_text.py ${preds} ${preds}.proc
 # python src/llama_recipes/utils/compute_wer.py ${trans} ${preds}.proc ${preds}.proc.wer
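For reference, a minimal sketch of the kind of WER computation the commented commands above perform over the paired `_gt`/`_pred` logs. The `<key> <text>` line format is an assumption, not confirmed by this diff; the repo's actual preprocess_text.py and compute_wer.py may differ.

# Minimal WER sketch. Assumes each decode-log line is "<key> <text>".
import sys

def edit_distance(ref, hyp):
    # Word-level Levenshtein distance via a rolling 1-D DP table.
    dp = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, len(hyp) + 1):
            cur = dp[j]
            dp[j] = min(dp[j] + 1,                          # deletion
                        dp[j - 1] + 1,                      # insertion
                        prev + (ref[i - 1] != hyp[j - 1]))  # substitution
            prev = cur
    return dp[-1]

def load(path):
    table = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split(maxsplit=1)
            if parts:
                table[parts[0]] = parts[1] if len(parts) > 1 else ""
    return table

if __name__ == "__main__":
    trans, preds = load(sys.argv[1]), load(sys.argv[2])
    errors = words = 0
    for key, ref in trans.items():
        hyp = preds.get(key, "").split()
        errors += edit_distance(ref.split(), hyp)
        words += len(ref.split())
    print(f"WER: {errors / max(words, 1):.4f}")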
14 changes: 7 additions & 7 deletions scripts/finetune_asr_vicuna.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 # export PYTHONPATH=/root/whisper:$PYTHONPATH
 export PYTHONPATH=/root/fairseq:$PYTHONPATH
-export CUDA_VISIBLE_DEVICES=4,5,6,7
+export CUDA_VISIBLE_DEVICES=2,3,4,5
 # export CUDA_LAUNCH_BLOCKING=1
 export OMP_NUM_THREADS=1
 
@@ -12,13 +12,13 @@ export OMP_NUM_THREADS=1

 cd /root/SLAM-LLM
 
-speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt
-# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt
+# speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt
+speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt
 
 llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5
 # llm_path=/nfs/maziyang.mzy/models/vicuna-13b-v1.5
 
-output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-paddingr-20240112
+output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-qwen-prompt-padding30-20240113
 
 # -m debugpy --listen 5678 --wait-for-client
 if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
@@ -61,6 +61,7 @@ else
 torchrun \
 --nnodes 1 \
 --nproc_per_node 4 \
+--master_port=29502 \
 src/llama_recipes/pipeline/finetune.py \
 --model_name asr \
 --freeze_encoder \
@@ -81,8 +82,8 @@ src/llama_recipes/pipeline/finetune.py \
 --speech_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
 --batching_strategy custom \
 --num_epochs 100 \
---batch_size_training 6 \
---val_batch_size 6 \
+--batch_size_training 4 \
+--val_batch_size 4 \
 --num_workers_dataloader 4 \
 --lr 1e-4 \
 --output_dir $output_dir \
@@ -97,7 +98,6 @@ src/llama_recipes/pipeline/finetune.py \
 # --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4" \
 # --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4/model.pt" \
 # --use_peft --peft_method lora \
-# --master_port=29501 \
 fi
 
 # {"key": "1001-134707-0000_ASR", "prompt": "<ASR>", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav", "target": "1 little recks the laborer. How near his work is holding him to God, The loving laborer through space and time, after all, not to create, only or found only.", "target_len": 157, "source_len": 1581, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": "<ASR>"}
16 changes: 8 additions & 8 deletions scripts/inference_asr_batch.sh
@@ -1,19 +1,19 @@
 #!/bin/bash
 #export PYTHONPATH=/root/whisper:$PYTHONPATH
-export CUDA_VISIBLE_DEVICES=1
+export CUDA_VISIBLE_DEVICES=7
 # export CUDA_LAUNCH_BLOCKING=1
 
 cd /root/SLAM-LLM
 
-speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt
-# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt
+# speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt
+speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt
 
 # llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
-llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf
-# llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5
+# llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf
+llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5
 
-output_dir=/nfs/maziyang.mzy/exps/llama-2-chat-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-padding30-20240111
-ckpt_path=$output_dir/asr/3
+output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-qwen-prompt-padding30-20240113
+ckpt_path=$output_dir/asr/2
 # peft_ckpt=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102/asr/4
 val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_test_other_filtered.jsonl
 decode_log=$ckpt_path/decode_log_test_other_beam4_repetition_penalty1
@@ -22,7 +22,7 @@ decode_log=$ckpt_path/decode_log_test_other_beam4_repetition_penalty1
 python src/llama_recipes/pipeline/inference_batch.py \
 --model_name asr \
 --freeze_encoder \
---llm_name llama-2-7b-chat-hf \
+--llm_name vicuna-7b-v1.5 \
 --llm_path $llm_path \
 --llm_dim 4096 \
 --encoder_name whisper \
6 changes: 3 additions & 3 deletions src/llama_recipes/datasets/audio_dataset.py
@@ -147,17 +147,17 @@ def collator(self, samples):
         audio_mel_mask = torch.zeros(len(samples), audio_mel_max_length)
         for line, sample in enumerate(samples):
             audio_mel_mask[line, :sample['audio_mel'].shape[0]] = 1
-        audio_mask = torch.zeros_like(attention_mask)
+        modality_mask = torch.zeros_like(attention_mask)
         for line, sample in enumerate(samples):
-            audio_mask[line, :sample['audio_length']] = 1
+            modality_mask[line, :sample['audio_length']] = 1
 
         return {
             'input_ids': input_ids,
             'labels': labels,
             'attention_mask': attention_mask,
             'audio_mel': audio_mel,
             'audio_mel_mask': audio_mel_mask,
-            'audio_mask': audio_mask
+            'modality_mask': modality_mask
         }
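The rename from `audio_mask` to `modality_mask` does not change semantics: the mask still marks which token positions the encoder features will occupy. A toy illustration of the construction, with invented shapes:

import torch

# Two samples whose audio features occupy the first 3 and 5 positions
# of a length-8 padded sequence (invented numbers).
attention_mask = torch.ones(2, 8, dtype=torch.bool)
audio_lengths = [3, 5]

modality_mask = torch.zeros_like(attention_mask)
for i, n in enumerate(audio_lengths):
    modality_mask[i, :n] = 1

print(modality_mask.int())
# tensor([[1, 1, 1, 0, 0, 0, 0, 0],
#         [1, 1, 1, 1, 1, 0, 0, 0]], dtype=torch.int32)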


38 changes: 5 additions & 33 deletions src/llama_recipes/datasets/avsr_dataset.py
@@ -252,9 +252,9 @@ def collator(self, samples):
         attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False)
                                       for s in samples])
 
-        audio_mask = torch.zeros_like(attention_mask)
+        modality_mask = torch.zeros_like(attention_mask)
         for line, sample in enumerate(samples):
-            audio_mask[line, :sample['audio_length']] = 1  # downsample, then /5
+            modality_mask[line, :sample['audio_length']] = 1  # downsample, then /5

# audio & mask
if not self.modal == "VO":
@@ -279,54 +279,26 @@ def collator(self, samples):
             vis_len = None
 
         inputBatch = (aud_seq_list, aud_padding_mask, vis_seq_list, vis_len)  #!!!
 
-        # targetinBatch = pad_sequence([data[1] for data in dataBatch], batch_first=True)
-        # targetoutBatch = pad_sequence([data[2] for data in dataBatch], batch_first=True)
-        # targetLenBatch = torch.stack([data[3] for data in dataBatch])
         targetinBatch = pad_sequence([data['trgtin'] for data in samples], batch_first=True)
         targetoutBatch = pad_sequence([data['trgtout'] for data in samples], batch_first=True)
         targetLenBatch = torch.stack([data['trgtLen'] for data in samples])
 
         if self.modal == "AO":
             inputBatch = (inputBatch[0].float(), inputBatch[1], None, None)
         elif self.modal == "VO":
             inputBatch = (None, None, inputBatch[2].float(), inputBatch[3].int())
         else:
             inputBatch = (inputBatch[0].float(), inputBatch[1], inputBatch[2].float(), inputBatch[3].int())
 
         targetinBatch = targetinBatch.int()
         targetoutBatch = targetoutBatch.int()
         targetLenBatch = targetLenBatch.int()
         targetMask = torch.zeros_like(targetoutBatch, device=targetoutBatch.device)
         targetMask[(torch.arange(targetMask.shape[0]), targetLenBatch.long() - 1)] = 1
         targetMask = (1 - targetMask.flip([-1]).cumsum(-1).flip([-1])).bool()
 
-        # return {
-        #     "inputBatch0": inputBatch[0],
-        #     "inputBatch1": inputBatch[1],
-        #     "inputBatch2": inputBatch[2],
-        #     "inputBatch3": inputBatch[3],
-
-        #     "targetoutBatch": targetoutBatch,
-        #     "targetLenBatch": targetLenBatch.long(),
-        #     'maskw2v': True,
-        # }
 
         return {
             'input_ids': input_ids,            # torch.Size([4, 114])
             'labels': labels,                  # torch.Size([4, 114])
             'attention_mask': attention_mask,  # torch.Size([4, 114])
-            # 'audio_mel': audio_mel,
-            # 'audio_mel_post_mask': audio_mel_post_mask,
-            'audio_mask': audio_mask,
+            'modality_mask': modality_mask,
 
             "audio": inputBatch[0],            # torch.Size([4, 92800])
-            "audiomask": inputBatch[1],        # torch.Size([4, 92800])
+            "audio_mask": inputBatch[1],       # torch.Size([4, 92800])
             "visual": inputBatch[2],           # torch.Size([4, 146, 1, 112, 112])
             "vis_len": inputBatch[3],          # torch.Size([4])
 
             "targetoutBatch": targetoutBatch,          # torch.Size([4, 50])
             "targetLenBatch": targetLenBatch.long(),   # torch.Size([4])
             'maskw2v': True,
         }
 
     def pad(self, sequence, max_length, padding_idx=0):
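Unchanged in this PR but worth unpacking: the three `targetMask` lines above scatter a 1 at each sequence's last valid index, and the flip/cumsum/flip turns that one-hot into a suffix mask, leaving `True` exactly on the padding tail. A toy check with invented lengths:

import torch

targetLenBatch = torch.tensor([4, 6])            # valid lengths (invented)
targetMask = torch.zeros(2, 6, dtype=torch.int)  # same shape as targetoutBatch

targetMask[(torch.arange(2), targetLenBatch.long() - 1)] = 1
targetMask = (1 - targetMask.flip([-1]).cumsum(-1).flip([-1])).bool()

print(targetMask)
# tensor([[False, False, False, False,  True,  True],
#         [False, False, False, False, False, False]])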
6 changes: 3 additions & 3 deletions src/llama_recipes/datasets/speech_dataset.py
@@ -149,17 +149,17 @@ def collator(self, samples):
         for line, sample in enumerate(samples):
             audio_mel_post_mask[line, :(sample['audio_mel'].shape[0] + 1) // 2] = 1
 
-        audio_mask = torch.zeros_like(attention_mask)
+        modality_mask = torch.zeros_like(attention_mask)
         for line, sample in enumerate(samples):
-            audio_mask[line, :sample['audio_length']] = 1
+            modality_mask[line, :sample['audio_length']] = 1
 
         return {
             'input_ids': input_ids,
             'labels': labels,
             'attention_mask': attention_mask,
             'audio_mel': audio_mel,
             'audio_mel_post_mask': audio_mel_post_mask,
-            'audio_mask': audio_mask
+            'modality_mask': modality_mask
         }


6 changes: 3 additions & 3 deletions src/llama_recipes/datasets/speech_dataset_inference.py
@@ -127,9 +127,9 @@ def collator(self, samples):
         for line, sample in enumerate(samples):
             audio_mel_post_mask[line, :(sample['audio_mel'].shape[0] + 1) // 2] = 1
 
-        audio_mask = torch.zeros_like(attention_mask)
+        modality_mask = torch.zeros_like(attention_mask)
         for line, sample in enumerate(samples):
-            audio_mask[line, :sample['audio_length']] = 1
+            modality_mask[line, :sample['audio_length']] = 1
         keys = [s['key'] for s in samples]
         targets = [s['target'] for s in samples]
 
@@ -138,7 +138,7 @@ def collator(self, samples):
             'attention_mask': attention_mask,
             'audio_mel': audio_mel,
             'audio_mel_post_mask': audio_mel_post_mask,
-            'audio_mask': audio_mask,
+            'modality_mask': modality_mask,
             'keys': keys,
             'targets': targets
         }
27 changes: 12 additions & 15 deletions src/llama_recipes/models/slam_model.py
@@ -185,16 +185,13 @@ def forward(self,
         audio_mel = kwargs.get("audio_mel", None)
         audio_mel_mask = kwargs.get("audio_mel_mask", None)
         audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None)  # 2x downsample for whisper
-        audio_mask = kwargs.get("audio_mask", None)
-
-        audio = kwargs.get("audio", None)                    # torch.Size([2, 96480])
-        audiomask = kwargs.get("audiomask", None)            # to remove  # torch.Size([2, 96480])
-        visual = kwargs.get("visual", None)                  # torch.Size([2, 151, 1, 112, 112])
-        vis_len = kwargs.get("vis_len", None)                # tensor([ 77, 151], device='cuda:0', dtype=torch.int32)
-        maskw2v = kwargs.get("maskw2v", None)                # True
-        targetoutBatch = kwargs.get("targetoutBatch", None)  # torch.Size([2, 29])
-        targetLenBatch = kwargs.get("targetLenBatch", None)  # tensor([18, 29], device='cuda:0')
+        modality_mask = kwargs.get("modality_mask", None)
 
+        audio = kwargs.get("audio", None)
+        audio_mask = kwargs.get("audio_mask", None)
+        visual = kwargs.get("visual", None)
+        vis_len = kwargs.get("vis_len", None)
+        maskw2v = kwargs.get("maskw2v", False)  # (FIX:MZY) False for supervised learning and inference
 
 
         encoder_outs = None
@@ -204,7 +201,7 @@
         if self.model_config.encoder_name == "beats":
             encoder_outs, audio_mel_post_mask = self.encoder.extract_features(audio_mel, audio_mel_mask)  # bs*seq*dim
         if self.model_config.encoder_name == "moco_wav2vec2":
-            encoder_outs, inputLenBatch, audio_mel_post_mask = self.encoder((audio, audiomask, visual, vis_len), maskw2v)  # bs*seq*dim
+            encoder_outs, inputLenBatch, audio_mel_post_mask = self.encoder((audio, audio_mask, visual, vis_len), maskw2v)  # bs*seq*dim
 
         if self.model_config.encoder_projector == "q-former":
             encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
@@ -220,11 +217,11 @@
         else:
             inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
 
-        if audio_mask is not None:
+        if modality_mask is not None:
             batch_size, token_num, dims = inputs_embeds.shape
             _, l, _ = encoder_outs.shape
             encoder_outs_pad = F.pad(encoder_outs, (0, 0, 0, token_num - l, 0, 0), value=0.0)
-            inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (~audio_mask[:, :, None])
+            inputs_embeds = encoder_outs_pad * modality_mask[:, :, None] + inputs_embeds * (~modality_mask[:, :, None])
 
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
 
@@ -253,7 +250,7 @@ def generate(self,
         audio_mel = kwargs.get("audio_mel", None)
         audio_mel_mask = kwargs.get("audio_mel_mask", None)
         audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None)  # 2x downsample for whisper
-        audio_mask = kwargs.get("audio_mask", None)
+        modality_mask = kwargs.get("modality_mask", None)
 
         encoder_outs = None
         if audio_mel is not None:
@@ -276,11 +273,11 @@
         else:
             inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
 
-        if audio_mask is not None:
+        if modality_mask is not None:
             batch_size, token_num, dims = inputs_embeds.shape
             _, l, _ = encoder_outs.shape
             encoder_outs_pad = F.pad(encoder_outs, (0, 0, 0, token_num - l, 0, 0), value=0.0)
-            inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (~audio_mask[:, :, None])
+            inputs_embeds = encoder_outs_pad * modality_mask[:, :, None] + inputs_embeds * (~modality_mask[:, :, None])
 
         model_outputs = self.llm.generate(
             inputs_embeds=inputs_embeds,
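The splice that `modality_mask` drives in both `forward` and `generate` is worth seeing in isolation: encoder features are right-padded to the text sequence length, then substituted into the masked positions of the token embeddings. A self-contained sketch with toy dimensions; the 4-tuple pad below is equivalent to the 6-tuple form above, whose extra pair pads the batch dim by zero:

import torch
import torch.nn.functional as F

batch, token_num, dims, l = 2, 10, 4, 6
inputs_embeds = torch.randn(batch, token_num, dims)   # token embeddings
encoder_outs = torch.randn(batch, l, dims)            # projected audio features
modality_mask = torch.zeros(batch, token_num, dtype=torch.bool)
modality_mask[:, :l] = True                           # audio occupies the prefix

# Right-pad the encoder output along the sequence dim, then splice it in.
encoder_outs_pad = F.pad(encoder_outs, (0, 0, 0, token_num - l), value=0.0)
mixed = encoder_outs_pad * modality_mask[:, :, None] \
        + inputs_embeds * (~modality_mask[:, :, None])

assert torch.equal(mixed[:, :l], encoder_outs)          # audio positions replaced
assert torch.equal(mixed[:, l:], inputs_embeds[:, l:])  # text positions kept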
42 changes: 42 additions & 0 deletions src/llama_recipes/utils/compute_ppl.py
@@ -0,0 +1,42 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from tqdm import tqdm
import json

MODEL_PATH = "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)

device = 'cuda:6'
model.to(device)
model.eval()

corpus_path = "/nfs/maziyang.mzy/data/librispeech/librispeech_test_other_filtered.jsonl"
corpus = []
with open(corpus_path, encoding='utf-8') as fin:
    for line in fin:
        data_dict = json.loads(line.strip())
        target = data_dict.get("target")
        if target:  # skip rows with a missing or empty "target" field
            corpus.append(target)

cumulative_log_likelihood = 0
total_tokens = 0

for sentence in tqdm(corpus):
    inputs = tokenizer(sentence, return_tensors="pt").to(device)

    input_ids = inputs["input_ids"]
    # Labels are shifted inside the model, so only input_len - 1 tokens are scored.
    num_scored = input_ids.size(1) - 1
    total_tokens += num_scored

    with torch.no_grad():
        outputs = model(**inputs, labels=input_ids)
        # outputs.loss is the mean NLL per scored token; undo the averaging.
        log_likelihood = outputs.loss * num_scored
        cumulative_log_likelihood += log_likelihood.item()


average_log_likelihood = cumulative_log_likelihood / total_tokens
corpus_ppl = torch.exp(torch.tensor(average_log_likelihood)).item()

print(f"Model: {MODEL_PATH}")
print(f"Corpus: {corpus_path}")
print(f"Corpus Perplexity: {corpus_ppl}")