Merge pull request #179 from X-LANCE/yxdu
Yxdu
ddlBoJack authored Nov 27, 2024
2 parents 6c26585 + f887fc7 commit 8d2dc88
Showing 7 changed files with 36 additions and 124,300 deletions.
15,529 changes: 0 additions & 15,529 deletions examples/st_covost2/covost2_zh.jsonl

This file was deleted.

53 changes: 11 additions & 42 deletions examples/st_covost2/dataset/hf_dataset.py
@@ -27,47 +27,12 @@ def __init__(self,
super().__init__()
self.mel_size = dataset_config.get("mel_size", 80) # 80 for whisper large v1 and v2, 128 for large v3

rank = dist.get_rank()

data_name = "yxdu/covost2_en_x"
local_dataset_path = data_name.split("/")[-1]+"_"+split+"_cache"

if os.path.exists(local_dataset_path):
    ds = load_from_disk(local_dataset_path)
    print(ds)
else:
    if rank==0:
        ds = load_dataset(data_name, split=split)
        ds = ds.cast_column("audio", Audio(sampling_rate=16000))
        print(ds)

        def prepare_dataset(example):
            audio_raw = whisper.pad_or_trim(example["audio"]["array"])
            audio_raw = torch.tensor(audio_raw, dtype=torch.float32)
            audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=self.mel_size).permute(1, 0)
            example["audio_mel"] = audio_mel
            return example

        ds = ds.map(prepare_dataset, remove_columns="audio")
        ds.save_to_disk(local_dataset_path)

    dist.barrier()
    if rank != 0:
        if os.path.exists(local_dataset_path):
            ds = load_from_disk(local_dataset_path)
        else:
            raise FileNotFoundError("No dataset found.")

if split=="val":
split="validation"
ds = load_dataset("yxdu/covost2_en_x",split=split)
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
print(ds)


self.ds = ds
self.tokenizer = tokenizer
@@ -111,8 +76,12 @@ def __getitem__(self, index):
print(target)
self.printed = True

audio_raw = whisper.pad_or_trim(data_dict["audio"]["array"])
audio_raw = torch.tensor(audio_raw, dtype=torch.float32)
audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=self.mel_size).permute(1, 0)

if self.bf16:
    audio_mel = torch.tensor(data_dict["audio_mel"], dtype=torch.bfloat16)
    audio_mel = audio_mel.to(torch.bfloat16)


if self.fix_length_audio > 0:
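Net effect of this file's changes: the rank-0 preprocess-and-cache pipeline (ds.map plus save_to_disk/load_from_disk guarded by dist.barrier) is dropped, and log-mel features are instead computed per item in __getitem__. A minimal sketch of that on-the-fly extraction, assuming a recent openai-whisper (the n_mels argument matters for large-v3's 128 bins; 80 is used for large v1/v2):

import torch
import whisper

def extract_mel(audio_array, mel_size=80):
    # pad or truncate the waveform to Whisper's fixed 30-second window
    audio = whisper.pad_or_trim(audio_array)
    audio = torch.tensor(audio, dtype=torch.float32)
    # (n_mels, frames) -> (frames, n_mels), matching the dataset's layout
    return whisper.log_mel_spectrogram(audio, n_mels=mel_size).permute(1, 0)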
10 changes: 6 additions & 4 deletions examples/st_covost2/inference_asr_batch.py
@@ -76,11 +76,12 @@ def __len__(self):
def Inference(kwargs: DictConfig):

# Update the configuration for the training and sharding process
train_config, fsdp_config, model_config, log_config, dataset_config = kwargs.train_config, \
train_config, fsdp_config, model_config, log_config, dataset_config,ckpt_path = kwargs.train_config, \
kwargs.fsdp_config, \
kwargs.model_config, \
kwargs.log_config, \
kwargs.dataset_config
kwargs.dataset_config, \
kwargs.ckpt_path

OmegaConf.set_struct(kwargs,False)
del kwargs["train_config"]
@@ -114,8 +115,8 @@ def Inference(kwargs: DictConfig):

config = AutoConfig.from_pretrained("Qwen/Qwen2-7B")  # load the Qwen2-7B configuration
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")
model = CustomSLM(config,ckpt_path="cotst/model.pt")
model = CustomSLM(config,ckpt_path=ckpt_path)
# model = AutoModel.from_pretrained("/home/yxdu/hit/SLAM-LLM/examples/st_covost2/output/step_10/test")


device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # FIX(MZY): put the whole model to device.
@@ -143,6 +144,7 @@
batch_size=train_config.val_batch_size,
drop_last=False,
prefetch_factor=1000,
persistent_workers=True,
collate_fn=dataset_test.collator
)

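Besides threading ckpt_path through the Hydra config instead of hard-coding "cotst/model.pt", this file turns on persistent_workers for the test DataLoader. A hedged sketch of the resulting loader setup; the batch size and worker count here are illustrative, not values from the repo:

from torch.utils.data import DataLoader

loader = DataLoader(
    dataset_test,
    batch_size=4,               # illustrative; the script uses train_config.val_batch_size
    num_workers=4,              # persistent_workers requires num_workers > 0
    prefetch_factor=1000,       # batches pre-loaded per worker, as in the diff
    persistent_workers=True,    # workers survive across epochs, avoiding respawn cost
    drop_last=False,
    collate_fn=dataset_test.collator,
)

Note that prefetch_factor counts batches per worker, so the diff's value of 1000 trades memory for throughput.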
20 changes: 10 additions & 10 deletions examples/st_covost2/model/slm_model.py
@@ -66,12 +66,10 @@ def forward(self,
audio_mel = kwargs.get("audio_mel", None)
audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None) # 2x downsample for whisper


encoder_outs = self.encoder(audio_mel.permute(0, 2, 1)).last_hidden_state # bs*seq*dim
encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)

input_ids = input_ids[:, 80:]

inputs_embeds = self.llm.model.embed_tokens(input_ids)
inputs_embeds = torch.cat((encoder_outs, inputs_embeds), dim=1)

@@ -80,14 +78,16 @@


model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels,)
acc = -1
if self.metric:
    with torch.no_grad():
        preds = torch.argmax(input=model_outputs.logits, dim=-1)
        acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100)

return model_outputs, acc


with torch.no_grad():
    preds = torch.argmax(input=model_outputs.logits, dim=-1)
    acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100)
print(acc)

return model_outputs

# return model_outputs, acc

@torch.no_grad()
def generate(self,
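The forward pass now computes token accuracy unconditionally, prints it, and returns only model_outputs; the old metric-gated return of (model_outputs, acc) survives only as a comment. The one-token shift, preds[:, :-1] against labels[:, 1:], reflects that the logit at position t predicts token t+1. compute_accuracy itself is defined elsewhere in SLAM-LLM; a plausible minimal sketch, assuming -100 marks positions to ignore:

import torch

def compute_accuracy(preds, labels, ignore_label=-100):
    # count only positions that carry a real target token
    mask = labels != ignore_label
    correct = ((preds == labels) & mask).sum()
    return correct.float() / mask.sum().clamp(min=1)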
18 changes: 7 additions & 11 deletions examples/st_covost2/scripts/all.sh
@@ -1,7 +1,7 @@
# export TOKENIZERS_PARALLELISM=false
export WANDB_MODE=offline
# export HYDRA_FULL_ERROR=1

export CUDA_VISIBLE_DEVICES=0,1
if command -v nvidia-smi &> /dev/null; then
gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
if [ -n "$CUDA_VISIBLE_DEVICES" ]; then
@@ -15,7 +15,7 @@ current_dir=$(dirname "$current_script")
code_dir=$(realpath "$current_dir/../../../../")
cd ${code_dir}/SLAM-LLM

source=all
source=zh

checkpoint_dir=${code_dir}/speech/data/qwen/spt-all-7B-4
output_dir=${code_dir}/speech/data/qwen/cotst-all
@@ -24,11 +24,6 @@ encoder_path_hf=${code_dir}/speech/models/whisper-large-v3
llm_path=${code_dir}/speech/models/Qwen2-7B


#change your train data
train_data_path=${code_dir}/SLAM-LLM/examples/st_covost2/test_st.jsonl
val_data_path=${code_dir}/SLAM-LLM/examples/st_covost2/test_st.jsonl




max_epoch=$(ls -d ${checkpoint_dir}/asr_epoch_*_step_* | sed -n 's/.*asr_epoch_\([0-9]*\)_step_\([0-9]*\).*/\1/p' | sort -n | tail -1)
@@ -40,7 +35,7 @@ final_path="${checkpoint_dir}/asr_epoch_${max_epoch}_step_${max_step}"


ckpt_name=$final_path/model.pt

ckpt_name=/home/yxdu/hit/SLAM-LLM/cotst/model.pt
# Use find to search for all .pt files and take the most recently modified one


@@ -62,7 +57,8 @@ hydra.run.dir=$output_dir \
++model_config.encoder_dim=1280 \
++model_config.encoder_projector=q-former \
++model_config.query_len=80 \
++dataset_config.dataset=st_dataset \
++dataset_config.dataset=hf_dataset \
++dataset_config.file=examples/st_covost2/dataset/hf_dataset.py:get_speech_dataset \
++dataset_config.train_data_path=$train_data_path \
++dataset_config.val_data_path=$val_data_path \
++dataset_config.input_type=mel \
@@ -74,7 +70,7 @@ hydra.run.dir=$output_dir \
++train_config.freeze_encoder=true \
++train_config.freeze_llm=true \
++train_config.batching_strategy=custom \
++train_config.gradient_accumulation_steps=1 \
++train_config.gradient_accumulation_steps=8 \
++train_config.warmup_steps=1000 \
++train_config.total_steps=1000000 \
++train_config.lr=1e-5 \
@@ -101,7 +97,7 @@ torchrun \
++fsdp_config.pure_bf16=true \
++log_config.use_wandb=true \
++log_config.wandb_project_name=cot \
++train_config.validation_interval=100 \
++train_config.validation_interval=10000 \
++log_config.wandb_exp_name=all \
++train_config.use_peft=false \
$hydra_args
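This script now points dataset_config.dataset at hf_dataset and passes a "file.py:function" spec for the dataset factory. A sketch of how such a spec is typically resolved at runtime; this is illustrative, and SLAM-LLM's actual loader may differ in details:

import importlib.util

def load_factory(spec):
    # split "path/to/file.py:get_speech_dataset" into module path and function name
    path, func_name = spec.rsplit(":", 1)
    module_spec = importlib.util.spec_from_file_location("custom_dataset", path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return getattr(module, func_name)

get_speech_dataset = load_factory(
    "examples/st_covost2/dataset/hf_dataset.py:get_speech_dataset")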
3 changes: 2 additions & 1 deletion examples/st_covost2/scripts/infer_enzh.sh
@@ -1,7 +1,7 @@
export MASTER_ADDR=localhost
export MASTER_PORT=12345
export WANDB_MODE=offline

export CUDA_VISIBLE_DEVICES=2,3
if command -v nvidia-smi &> /dev/null; then
gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
if [ -n "$CUDA_VISIBLE_DEVICES" ]; then
@@ -32,6 +32,7 @@ if [ ! -f "$ckpt_path" ]; then
echo "Download ckpt..."
git clone https://huggingface.co/yxdu/cotst
fi

echo $ckpt_path


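The script fetches the checkpoint by cloning https://huggingface.co/yxdu/cotst when model.pt is absent. An equivalent fetch in Python, assuming the huggingface_hub package is installed (the script itself uses git clone):

import os
from huggingface_hub import snapshot_download

# downloads (or reuses a cached copy of) the yxdu/cotst repository
ckpt_dir = snapshot_download(repo_id="yxdu/cotst")
ckpt_path = os.path.join(ckpt_dir, "model.pt")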
