Commit 297d65a

Merge pull request #12 from ddlBoJack/dev-mzy
update q-former
ddlBoJack authored Dec 15, 2023
2 parents 6a792c8 + 6316b72 commit 297d65a
Showing 8 changed files with 134 additions and 74 deletions.
30 changes: 15 additions & 15 deletions scripts/finetune_speech_pretraining.sh
@@ -1,6 +1,6 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0,1,2,3
export CUDA_VISIBLE_DEVICES=0
export CUDA_LAUNCH_BLOCKING=1
export OMP_NUM_THREADS=1

@@ -14,7 +14,7 @@ cd /root/SLAM-LLM
speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt
llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048
output_dir=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-qformer64-proj2048-lr1e-5-whisper-test

# -m debugpy --listen 5678 --wait-for-client
if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
@@ -27,21 +27,21 @@ python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/fin
--encoder_name whisper \
--encoder_ds_rate 2 \
--encoder_path $speech_encoder_path \
--encoder_projector linear \
--encoder_projector_ds_rate 5 \
--encoder_projector q-former \
--dataset custom_dataset \
--custom_dataset.fix_length_audio 64 \
--custom_dataset.file src/llama_recipes/datasets/speech_dataset.py:get_audio_dataset \
--custom_dataset.train_data_path /nfs/beinian.lzr/workspace/datasets/speech_llm/train_dataset/data_wav_json/asr/librispeech_train_960h_wav_speech_llm_train_data.json \
--custom_dataset.val_data_path /nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/dev_other/librispeech_dev_other.jsonl \
--custom_dataset.train_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.trans.jsonl \
--custom_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
--batching_strategy custom \
--num_epochs 100 \
--batch_size_training 4 \
--val_batch_size 4 \
--lr 1e-5 \
--output_dir $output_dir \
--run_test_during_validation \
--run_test_during_validation_file "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav" \
--run_test_during_validation_prompt "<|ASR|>" \
--run_test_during_validation_file "/nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/test_other/wav/1688-142285-0000.wav" \
--run_test_during_validation_prompt "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " \
--metric acc \
# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \
@@ -54,7 +54,7 @@ torchrun \
src/llama_recipes/pipeline/finetune.py \
--model_name asr \
--freeze_encoder \
--freeze_llm \
--use_peft --peft_method lora \
--use_fp16 \
--enable_fsdp \
--llm_name llama-2-7b-hf \
@@ -66,22 +66,22 @@ src/llama_recipes/pipeline/finetune.py \
--encoder_projector_ds_rate 5 \
--dataset custom_dataset \
--custom_dataset.file src/llama_recipes/datasets/speech_dataset.py:get_audio_dataset \
--custom_dataset.train_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_train_960h_wav_speech_llm_train_data.json \
--custom_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other.jsonl \
--custom_dataset.train_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.trans.jsonl \
--custom_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
--batching_strategy custom \
--num_epochs 100 \
--batch_size_training 8 \
--val_batch_size 8 \
--batch_size_training 16 \
--val_batch_size 16 \
--num_workers_dataloader 4 \
--lr 1e-5 \
--output_dir $output_dir \
--run_test_during_validation \
--run_test_during_validation_file "/nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/test_other/wav/1688-142285-0000.wav" \
--run_test_during_validation_prompt "<|ASR|>" \
--run_test_during_validation_prompt "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " \
--metric acc \
# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \
# --use_peft --peft_method lora \
# --freeze_llm \
fi

# {"key": "1001-134707-0000_ASR", "prompt": "<ASR>", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav", "target": "1 little recks the laborer. How near his work is holding him to God, The loving laborer through space and time, after all, not to create, only or found only.", "target_len": 157, "source_len": 1581, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": "<ASR>"}
4 changes: 2 additions & 2 deletions scripts/inference_asr_batch.sh
@@ -1,6 +1,6 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=1
export CUDA_VISIBLE_DEVICES=0
export CUDA_LAUNCH_BLOCKING=1

cd /root/SLAM-LLM
@@ -26,7 +26,7 @@ python src/llama_recipes/pipeline/inference_batch.py \
--encoder_projector_ds_rate 5 \
--dataset custom_dataset \
--custom_dataset.file src/llama_recipes/datasets/speech_dataset_inference.py:get_audio_dataset \
--custom_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_test_other.jsonl \
--custom_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_test_other_filtered.jsonl \
--batching_strategy custom \
--num_epochs 1 \
--val_batch_size 8 \
3 changes: 2 additions & 1 deletion src/llama_recipes/configs/datasets.py
@@ -36,4 +36,5 @@ class custom_dataset:
train_data_path: str = NotImplemented
val_data_path: str = NotImplemented
max_words: int = NotImplemented
max_mel: int = NotImplemented
max_mel: int = NotImplemented
fix_length_audio: int = -1
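For reference, a sketch of the tail of the custom_dataset config dataclass after this change (earlier fields omitted; defaults as shown in the diff). A value of -1 keeps the speech length derived from the audio, while a positive value such as the 64 used in the finetuning script pins every utterance to that many placeholder positions:

from dataclasses import dataclass

@dataclass
class custom_dataset:
    train_data_path: str = NotImplemented
    val_data_path: str = NotImplemented
    max_words: int = NotImplemented
    max_mel: int = NotImplemented
    fix_length_audio: int = -1   # -1: derive from audio length; >0: fixed number of speech tokens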
21 changes: 12 additions & 9 deletions src/llama_recipes/datasets/speech_dataset.py
@@ -30,7 +30,8 @@ def __init__(self,
# self.data_list = contents
self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss
self.prompt_template = "USER: {}\n ASSISTANT:"
self.answer_template = "<|{}|>"
self.answer_template = "{}"
self.fix_length_audio = dataset_config.fix_length_audio

self.data_list = []
if split == "train":
@@ -69,12 +70,7 @@ def __getitem__(self, index):
speech_raw = whisper.load_audio(speech_path)
speech_mel = whisper.log_mel_spectrogram(speech_raw).permute(1, 0)

prompt = """
<|ASR|>
"""
answer = """
<|The moon looks so beautiful tonight.|>
"""
prompt = "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "

prompt = self.prompt_template.format(prompt)
answer = self.answer_template.format(target)
@@ -83,8 +79,11 @@ def __getitem__(self, index):

prompt_length = len(prompt_ids)
speech_length = (speech_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats
speech_length = speech_length // 5 # ad-hoc for 5x cov1d downsample
speech_pseudo = torch.full((speech_length,), -1)
speech_length = speech_length // 5 # ad-hoc for 5x fc downsample
# speech_length = calculate_output_length_1d(speech_length, 5, 5, 0) # ad-hoc for 5x cov1d downsample
if self.fix_length_audio > 0:
speech_length = self.fix_length_audio
speech_pseudo = torch.full((speech_length,), -1) # placeholder

example = prompt + answer # FIX(MZY): avoid putting a bos token before answer.
example_ids = self.tokenizer.encode(example) # [prompt,answer]
@@ -140,6 +139,9 @@ def collator(self, samples):
speech_mel_max_length = max([s['speech_mel'].shape[0] for s in samples])
speech_mel = torch.stack([self.pad(s['speech_mel'], speech_mel_max_length, 0)
for s in samples])
speech_mel_post_mask = torch.zeros(len(samples), (speech_mel_max_length + 1) // 2) # ad-hoc for whisper for 2x downsample from mel to feats
for line, sample in enumerate(samples):
speech_mel_post_mask[line, :(sample['speech_mel'].shape[0] + 1) // 2] = 1

speech_mask = torch.zeros_like(attention_mask)
for line, sample in enumerate(samples):
@@ -150,6 +152,7 @@ def collator(self, samples):
'labels': labels,
'attention_mask': attention_mask,
'speech_mel': speech_mel,
'speech_mel_post_mask': speech_mel_post_mask,
'speech_mask': speech_mask
}
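The new speech_mel_post_mask gives the q-former a padding mask at the encoder-feature resolution (after the 2x mel-to-feature downsampling), so its cross-attention can ignore padded frames within a batch. A minimal illustration with made-up lengths:

import torch

mel_lengths = [3000, 1500]                      # per-sample mel frames before padding (illustrative)
max_mel = max(mel_lengths)
post_len = (max_mel + 1) // 2                   # 2x downsample from mel frames to encoder features
speech_mel_post_mask = torch.zeros(len(mel_lengths), post_len)
for i, n in enumerate(mel_lengths):
    speech_mel_post_mask[i, :(n + 1) // 2] = 1  # 1 = real frame, 0 = padding

# speech_mel_post_mask.shape == (2, 1500); row sums are 1500 and 750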

7 changes: 7 additions & 0 deletions src/llama_recipes/datasets/speech_dataset_inference.py
@@ -30,6 +30,7 @@ def __init__(self,
# self.data_list = contents
self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss
self.prompt_template = "USER: {}\n ASSISTANT:"
self.fix_length_audio = dataset_config.fix_length_audio

self.data_list = []
if split == "train":
@@ -71,6 +72,8 @@ def __getitem__(self, index):

speech_length = (speech_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats
speech_length = speech_length // 5 # ad-hoc for 5x cov1d downsample
if self.fix_length_audio > 0:
speech_length = self.fix_length_audio
speech_pseudo = torch.full((speech_length,), -1)

prompt = """
@@ -120,6 +123,9 @@ def collator(self, samples):
speech_mel_max_length = max([s['speech_mel'].shape[0] for s in samples])
speech_mel = torch.stack([self.pad(s['speech_mel'], speech_mel_max_length, 0)
for s in samples])
speech_mel_post_mask = torch.zeros(len(samples), (speech_mel_max_length + 1) // 2) # ad-hoc for whisper for 2x downsample from mel to feats
for line, sample in enumerate(samples):
speech_mel_post_mask[line, :(sample['speech_mel'].shape[0] + 1) // 2] = 1

speech_mask = torch.zeros_like(attention_mask)
for line, sample in enumerate(samples):
@@ -131,6 +137,7 @@ def collator(self, samples):
'input_ids': input_ids,
'attention_mask': attention_mask,
'speech_mel': speech_mel,
'speech_mel_post_mask': speech_mel_post_mask,
'speech_mask': speech_mask,
'keys': keys,
'targets': targets
73 changes: 73 additions & 0 deletions src/llama_recipes/models/projector.py
@@ -0,0 +1,73 @@
import torch
import torch.nn as nn


class EncoderProjectorConcat(nn.Module):
def __init__(self, config):
super().__init__()
self.k = config.encoder_projector_ds_rate
self.linear1 = nn.Linear(1280 * self.k, 2048)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(2048, 4096)

def forward(self, x):
batch_size, seq_len, dim = x.size()
num_frames_to_discard = seq_len % self.k
if num_frames_to_discard > 0:
x = x[:, :-num_frames_to_discard, :]
seq_len = x.size(1)

x = x.view(batch_size, seq_len // self.k, dim * self.k)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x

class EncoderProjectorCov1d(nn.Module):
def __init__(self, config):
super().__init__()
self.conv1d = nn.Conv1d(in_channels=1280, out_channels=1280, kernel_size=config.encoder_projector_ds_rate, stride=config.encoder_projector_ds_rate, padding=0)
self.linear1 = nn.Linear(1280, 2048)
self.relu1 = nn.ReLU()
self.linear2 = nn.Linear(2048, 4096)
self.relu2 = nn.ReLU()

def forward(self, x):
x = x.transpose(1, 2)
x = self.conv1d(x)
x = x.transpose(1, 2)
x = self.relu1(x)
x = self.linear1(x)
x = self.relu2(x)
x = self.linear2(x)
return x

class EncoderProjectorQFormer(nn.Module):
def __init__(self, config):
super().__init__()
from transformers import Blip2QFormerConfig, Blip2QFormerModel
configuration = Blip2QFormerConfig()
configuration.encoder_hidden_size = 1280
configuration.num_hidden_layers = 2

self.query_len = 64
self.query = nn.Parameter(torch.zeros(1, self.query_len, configuration.hidden_size))
self.query.data.normal_(mean=0.0, std=1.0)
self.qformer = Blip2QFormerModel(configuration)

self.linear = nn.Linear(configuration.hidden_size, 4096)
self.norm = nn.LayerNorm(4096, eps=1e-5)

def forward(self, x, atts):
query = self.query.expand(x.shape[0], -1, -1)

query_output = self.qformer(
query_embeds=query,
encoder_hidden_states=x,
encoder_attention_mask=atts,
return_dict=True,
)

query_proj = self.norm(self.linear(query_output.last_hidden_state))

return query_proj
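A hypothetical smoke test for the new projector (the SimpleNamespace config and the shapes are illustrative, not from the repo): variable-length Whisper features plus the post-downsampling mask go in, and a fixed sequence of 64 query embeddings projected to the 4096-dim LLM space comes out.

import types
import torch
from llama_recipes.models.projector import EncoderProjectorQFormer

config = types.SimpleNamespace(encoder_projector_ds_rate=5)   # not used by the q-former branch
projector = EncoderProjectorQFormer(config)

feats = torch.randn(2, 750, 1280)    # (batch, encoder frames, Whisper large-v2 hidden size)
post_mask = torch.ones(2, 750)       # speech_mel_post_mask from the collator
out = projector(feats, post_mask)
print(out.shape)                     # torch.Size([2, 64, 4096])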
62 changes: 18 additions & 44 deletions src/llama_recipes/models/slam_model.py
@@ -20,6 +20,8 @@
from torch.nn import CrossEntropyLoss
from llama_recipes.utils.metric import compute_accuracy

from llama_recipes.models.projector import EncoderProjectorConcat, EncoderProjectorCov1d, EncoderProjectorQFormer


def setup_model(tokenizer, train_config, model_config, **kwargs):
return slam_model(tokenizer, train_config, model_config, **kwargs)
@@ -145,49 +147,13 @@ def setup_llm(train_config, model_config, **kwargs):
def setup_encoder_projector(train_config, model_config, **kwargs):
if model_config.encoder_projector == "linear":
encoder_projector = EncoderProjectorConcat(model_config)
elif model_config.encoder_projector == "cov1d-linear":
encoder_projector = EncoderProjectorCov1d(model_config)
elif model_config.encoder_projector == "q-former":
encoder_projector = EncoderProjectorQFormer(model_config)
print_module_size(encoder_projector, model_config.encoder_projector, int(os.environ["RANK"]) if train_config.enable_fsdp else 0)
return encoder_projector

class EncoderProjectorConcat(nn.Module):
def __init__(self, config):
super().__init__()
self.k = config.encoder_projector_ds_rate
self.linear1 = nn.Linear(1280 * self.k, 2048)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(2048, 4096)

def forward(self, x):
batch_size, seq_len, dim = x.size()
num_frames_to_discard = seq_len % self.k
if num_frames_to_discard > 0:
x = x[:, :-num_frames_to_discard, :]
seq_len = x.size(1)

x = x.view(batch_size, seq_len // self.k, dim * self.k)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x

class EncoderProjectorCov1d(nn.Module):
def __init__(self, config):
super(self).__init__()
self.conv1d = nn.Conv1d(in_channels=1280, out_channels=1280, kernel_size=config.encoder_projector_ds_rate, stride=config.encoder_projector_ds_rate, padding=0)
self.linear1 = nn.Linear(1280, 2048)
self.relu1 = nn.ReLU()
self.linear2 = nn.Linear(2048, 4096)
self.relu2 = nn.ReLU()

def forward(self, x):
x = x.transpose(1, 2)
x = self.conv1d(x)
x = x.transpose(1, 2)
x = self.relu1(x)
x = self.linear1(x)
x = self.relu2(x)
x = self.linear2(x)
return x


class slam_model(nn.Module):
def __init__(
@@ -228,12 +194,16 @@ def forward(self,
**kwargs,
):
speech_mel = kwargs.get("speech_mel", None)
speech_mel_post_mask = kwargs.get("speech_mel_post_mask", None)
speech_mask = kwargs.get("speech_mask", None)

encoder_outs = None
if speech_mel is not None:
encoder_outs = self.encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1)) # bs*seq*dim
encoder_outs = self.encoder_projector(encoder_outs)
if self.model_config.encoder_projector == "q-former":
encoder_outs = self.encoder_projector(encoder_outs, speech_mel_post_mask)
else:
encoder_outs = self.encoder_projector(encoder_outs)

if input_ids is not None:
input_ids[input_ids == -1] = 0
@@ -275,12 +245,16 @@ def generate(self,
**kwargs,
):
speech_mel = kwargs.get("speech_mel", None)
speech_mel_post_mask = kwargs.get("speech_mel_post_mask", None)
speech_mask = kwargs.get("speech_mask", None)

encoder_outs = None
if speech_mel is not None:
encoder_outs = self.encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1)) # bs*seq*dim
encoder_outs = self.encoder_projector(encoder_outs)
if self.model_config.encoder_projector == "q-former":
encoder_outs = self.encoder_projector(encoder_outs, speech_mel_post_mask)
else:
encoder_outs = self.encoder_projector(encoder_outs)

if input_ids is not None:
input_ids[input_ids == -1] = 0
@@ -300,7 +274,7 @@ def generate(self,
model_outputs = self.llm.generate(
inputs_embeds=inputs_embeds,
max_length=kwargs.get("max_length", 200),
num_beams=kwargs.get("num_beams", 4),
num_beams=kwargs.get("num_beams", 1),
do_sample=kwargs.get("do_sample", False),
min_length=kwargs.get("min_length", 1),
top_p=kwargs.get("top_p", 0.9),
@@ -330,7 +304,7 @@ def inference(
negative_prompt_ids = None,
negative_prompt_attention_mask = None,
**kwargs,
):
): # TODO: Now you need to set your customized sampling rate manually

device = kwargs.get("device", "cuda")
assert os.path.exists(wav_path)
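The projector selection and the forward/generate paths now mirror each other: only the q-former variant needs the encoder-frame mask. A condensed sketch of that dispatch, using the names from this diff:

from llama_recipes.models.projector import (
    EncoderProjectorConcat, EncoderProjectorCov1d, EncoderProjectorQFormer,
)

PROJECTORS = {                      # values accepted by --encoder_projector
    "linear": EncoderProjectorConcat,
    "cov1d-linear": EncoderProjectorCov1d,
    "q-former": EncoderProjectorQFormer,
}

def project(projector_name, encoder_projector, encoder_outs, speech_mel_post_mask):
    # the q-former cross-attends over encoder frames, so it also needs the padding mask
    if projector_name == "q-former":
        return encoder_projector(encoder_outs, speech_mel_post_mask)
    return encoder_projector(encoder_outs)

Separately, generate() above now defaults to num_beams=1 (greedy search) rather than 4; beam search can still be requested per call via kwargs.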
8 changes: 5 additions & 3 deletions src/llama_recipes/pipeline/inference_batch.py
@@ -53,15 +53,17 @@ def main(**kwargs):


print("=====================================")
with open(kwargs.get('decode_log'), "w") as decode_log:
pred_path = kwargs.get('decode_log') + "_pred_other"
gt_path = kwargs.get('decode_log') + "_gt_other"
with open(pred_path, "w") as pred, open(gt_path, "w") as gt:
for step, batch in enumerate(test_dataloader):
for key in batch.keys():
batch[key] = batch[key].to(device) if key not in ["keys", "targets"] else batch[key]
model_outputs = model.generate(**batch)
output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True)
for key, text, target in zip(batch["keys"], output_text, batch["targets"]):
decode_log.write(key + "\t" + text + "\n")
decode_log.write(key + "\t" + target + "\n")
pred.write(key + "\t" + text.replace("\n", " ") + "\n")
gt.write(key + "\t" + target + "\n")


if __name__ == "__main__":
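With the decode log split into <decode_log>_pred_other and <decode_log>_gt_other, hypotheses and references can be paired back up by key for external scoring. A hypothetical reader (the file names and the tab-separated key/text layout follow the writes above):

def load_decode_log(path):
    with open(path) as f:
        return dict(line.rstrip("\n").split("\t", 1) for line in f if "\t" in line)

hyps = load_decode_log("decode_log_pred_other")   # key -> predicted transcription
refs = load_decode_log("decode_log_gt_other")     # key -> reference transcription
pairs = [(refs[k], hyps[k]) for k in refs if k in hyps]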
