CTC-assisted LLM-based CASR codes PR #172

Closed
wants to merge 4 commits
52 changes: 52 additions & 0 deletions examples/contextual_asr/README.md
@@ -0,0 +1,52 @@
# CTC-Assisted LLM-Based Contextual ASR

## Guides

[CTC-Assisted LLM-Based Contextual ASR](https://arxiv.org/abs/2411.06437) is an LLM-based contextual ASR model that first uses CTC decoding results to filter potentially relevant hotwords from a pre-defined hotword list, and then incorporates them into the LLM prompt to improve recognition of those hotwords.
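
To make this concrete, below is a minimal, illustrative sketch of the filtering step, not the recipe's actual implementation: each pre-defined hotword is fuzzily matched against the CTC hypothesis, the best matches are kept, and the surviving hotwords are spliced into the LLM prompt. The `SequenceMatcher`-based scoring, the function names, and the prompt wording are assumptions; the default threshold of 0.9 and cap of 15 words mirror the `probability_threshold` and `word_num` fields in the config added by this PR.

```
# Illustrative sketch only -- not the recipe's actual filtering code.
from difflib import SequenceMatcher
from typing import List

def filter_hotwords(ctc_hypothesis: str, hotword_list: List[str],
                    threshold: float = 0.9, max_words: int = 15) -> List[str]:
    """Keep hotwords whose best fuzzy match inside the CTC hypothesis reaches `threshold`."""
    scored = []
    for hotword in hotword_list:
        window = len(hotword)
        best = 0.0
        # slide a window of the hotword's length over the CTC hypothesis
        for i in range(max(1, len(ctc_hypothesis) - window + 1)):
            segment = ctc_hypothesis[i:i + window]
            best = max(best, SequenceMatcher(None, hotword, segment).ratio())
        if best >= threshold:
            scored.append((best, hotword))
    scored.sort(reverse=True)
    return [word for _, word in scored[:max_words]]

def build_prompt(hotwords: List[str]) -> str:
    # hypothetical prompt wording; the recipe's prompt lives in conf/prompt.yaml
    return "Transcribe speech to text. Pay attention to these hotwords: " + " ".join(hotwords) + ". "
```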

## Model Architecture

We use the WavLM-Large model, pre-trained on 94,000 hours of data and fine-tuned on 960 hours of LibriSpeech data with CTC loss, as our speech encoder. We use the public Vicuna 7B as our large language model decoder, and a simple linear projector, consisting of a 1-D convolution layer and two linear layers, as our adapter. Refer to our [paper](https://arxiv.org/pdf/2411.06437) for more details.

![](docs/model.pdf)
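
For reference, the adapter can be sketched roughly as a strided 1-D convolution that downsamples the encoder frames, followed by two linear layers with a ReLU in between. This is an illustrative approximation: the hidden size of 2048 and the default dimensions below are assumptions, not values taken from the repository.

```
# Illustrative adapter sketch -- dimensions and hidden size are assumptions.
import torch
import torch.nn as nn

class LinearProjector(nn.Module):
    def __init__(self, encoder_dim: int = 1024, llm_dim: int = 4096, ds_rate: int = 5):
        super().__init__()
        # strided 1-D convolution downsamples the encoder frames by ds_rate
        self.conv = nn.Conv1d(encoder_dim, encoder_dim, kernel_size=ds_rate, stride=ds_rate)
        self.linear1 = nn.Linear(encoder_dim, 2048)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(2048, llm_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, encoder_dim) -> (batch, time // ds_rate, llm_dim)
        x = self.conv(x.transpose(1, 2)).transpose(1, 2)
        return self.linear2(self.relu(self.linear1(x)))
```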

## Checkpoints
We only train the linear projector in this recipe.

| Encoder | Projector | LLM |
|---|---|---|
| [CTC Fine-tuned WavLM-Large](https://drive.google.com/file/d/12ZmSSbDvx73W0eK1wpUgajapCLhqh5DI/view?usp=drive_link) (~315.45M) | [Linear](https://drive.google.com/file/d/1Zlbsnz1YUWtYtt-yNyoPK5OhR30kwLfS/view?usp=drive_link) (~15.74M) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) (~6.7B) |

## Performance
![](docs/performanc.png)


## Data preparation
The artificial biasing list constructed in [Contextualized streaming end-to-end speech recognition with trie-based deep biasing and shallow fusion](https://arxiv.org/pdf/2104.02194) is used for contextual ASR testing. Refer to the official [repo](https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias).
The 5,000 most frequent words in the LibriSpeech training corpus are categorized as common words, with the remainder classified as rare words. The biasing list generated for the test set consists of two parts: rare words in the transcriptions, and distractors sampled from the 209.2K-word rare-word vocabulary. Biasing lists of varying lengths are generated by adding N = {100, 500, 1000, 2000} distractors to the lists, as sketched below.
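
The following is an illustrative sketch of how one such per-utterance biasing list could be assembled; the official lists come from the fbai-speech repository linked above, and the function and its arguments are hypothetical.

```
# Illustrative sketch of building one biasing list -- not the official tooling.
import random
from typing import List, Set

def build_biasing_list(transcript: str, common_words: Set[str],
                       rare_vocab: List[str], n_distractors: int = 100,
                       seed: int = 0) -> List[str]:
    # rare words are everything outside the 5,000 most frequent training words
    rare_in_utt = sorted({w for w in transcript.lower().split() if w not in common_words})
    rng = random.Random(seed)
    pool = [w for w in rare_vocab if w not in rare_in_utt]
    return rare_in_utt + rng.sample(pool, n_distractors)
```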



## Decoding with checkpoints
```
bash decode_wavlm_libri960_ft_char_hotwords_filter_N100_F3.sh
```

## Training the model
```
bash finetune_wavlm_libri960_ft_char_hotwords.sh
bash finetune_wavlm_libri960_ft_char.sh
```

## Citation
You can refer to the paper for more results.
```
@article{yang2024ctc,
title={CTC-Assisted LLM-Based Contextual ASR},
author={Yang, Guanrou and Ma, Ziyang and Gao, Zhifu and Zhang, Shiliang and Chen, Xie},
journal={Proc. SLT},
year={2024}
}
```


19 changes: 19 additions & 0 deletions examples/contextual_asr/conf/ds_config.json
@@ -0,0 +1,19 @@
{
"train_micro_batch_size_per_gpu": 4,
"gradient_accumulation_steps": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-4
}
},
"fp16": {
"enabled": true
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu"
}
}
}
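
For context, a minimal hypothetical example of handing this JSON to DeepSpeed from Python; the recipe's own launch scripts may wire it up differently, and the stand-in model below is only a placeholder.

```
# Hypothetical usage sketch: let DeepSpeed build the Adam optimizer, fp16, and
# ZeRO-3 CPU offload from the JSON config above.
import deepspeed
import torch.nn as nn

model = nn.Linear(1024, 4096)  # stand-in for the real SLAM model
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="examples/contextual_asr/conf/ds_config.json",
)
```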
4 changes: 4 additions & 0 deletions examples/contextual_asr/conf/prompt.yaml
@@ -0,0 +1,4 @@
dataset_config:
# we put the prompt here because the hydra override in the shell script only supports a small subset of characters
# prompt: "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "
prompt: "Transcribe speech to text. "
137 changes: 137 additions & 0 deletions examples/contextual_asr/contextual_asr_config.py
@@ -0,0 +1,137 @@
from dataclasses import dataclass, field
from typing import Optional, List
@dataclass
class ModelConfig:
file: str = "examples/contextual_asr/model/slam_model_contextual_asr.py:model_factory"
llm_name: str = "vicuna-13b-v1.5"
llm_path: str = "PATH/to/LLAMA/7B"
llm_type: str = "decoder_only"
llm_dim: int = 4096
encoder_name: Optional[str] = None
encoder_ds_rate: int = 2
encoder_path: Optional[str] = None
encoder_dim: int = 1280
encoder_projector: str = "linear"
encoder_projector_ds_rate: int = 5
modal: str = "audio"
normalize: Optional[bool] = field(default=False, metadata={
"help": "whether input is normalized, used for models such as wavlm"
})
encoder_type: str = field(default="finetune", metadata={
"help": "whether model is only pretrained or finetuned, used for models such as hubert"
})

@dataclass
class PeftConfig:
peft_method: str = "lora" # None , llama_adapter, prefix
r: int = 8
lora_alpha: int = 32
# target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj" ])
target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj","k_proj","o_proj" ])
bias: str = "none"
task_type: str = "CAUSAL_LM"
lora_dropout: float = 0.05
inference_mode: bool = False

@dataclass
class TrainConfig:
model_name:str = "PATH/to/LLAMA/7B"
enable_ddp:bool = False
enable_deepspeed:bool = False
enable_fsdp:bool = False
low_cpu_fsdp:bool = False
run_validation:bool = True
batch_size_training:int = 4
batching_strategy:str = field(default="packing", metadata={
"help":"alternative: padding"
})
context_length:int = 4096
gradient_accumulation_steps:int = 1
num_epochs:int = 3
num_workers_dataloader:int = 1
warmup_steps:int = 1000
total_steps:int = 100000
validation_interval:int = 1000
lr:float = 1e-4
weight_decay:float = 0.0
gamma:float = 0.85
seed:int = 42
use_fp16:bool = False
mixed_precision:bool = True
val_batch_size:int = 1
use_peft:bool = False
peft_config:PeftConfig = field(default_factory=PeftConfig)
output_dir:str = "PATH/to/save/PEFT/model"
freeze_layers:bool = False
num_freeze_layers:int = 1
quantization:bool = False
one_gpu:bool = False
save_model:bool = True
dist_checkpoint_root_folder:str = "PATH/to/save/FSDP/model" # will be used if using FSDP
dist_checkpoint_folder:str = "fine-tuned" # will be used if using FSDP
save_optimizer:bool = False # will be used if using FSDP
use_fast_kernels:bool = False # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and xFormers memory-efficient kernels
run_test_during_validation:bool = False
run_test_during_validation_file:str = "test.wav"
run_test_during_validation_prompt:str = "<|ASR|>"
freeze_llm:bool = field(default=False, metadata={
"help": "whether to freeze llm when finetuning, should be true when use peft finetuning"
})
freeze_encoder:bool = False

@dataclass
class DataConfig:
dataset: str = "speech_dataset"
file: str = "examples/contextual_asr/dataset/hotwords_dataset.py:get_speech_dataset"
train_data_path: Optional[str] = None
val_data_path: Optional[str] = None
train_split: str = "train"
test_split:str = "validation"
prompt: Optional[str] = None
data_path: Optional[str] = None
max_words: Optional[int] = None
max_mel: Optional[float] = None
fix_length_audio: int = -1
inference_mode:bool = False
input_type: str = field(default="raw", metadata={
"help":"Use raw when input is wav, mel when for whisper"
})
mel_size: int = field(default=80, metadata={
"help": "80 for whisper large v1 and v2, 128 for v3"
})
normalize: Optional[bool] = field(default=False, metadata={
"help": "whether input is normalized, used for models such as wavlm"
})
infer_type: str = "bias"
infer_file: str = "/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/my_ref/test-clean.biasing_100.tsv"
ctc_file: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_large_libri_test_other_char.txt"
filter_type: str = "char"
phn_to_name_dict: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_ft_libri960_${ref_split}_phn.json"
common_words_5k_dir: str="/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/words/common_words_5k.txt"
probability_threshold: float = 0.9
word_num: int = 15
filter_infer_sentence: bool = False
filter_infer_sentence_few: bool = False
first: int = 1

@dataclass
class FSDPConfig:
mixed_precision: bool = True
use_fp16: bool = False
# sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
sharding_strategy: str = "NO_SHARD" # ShardingStrategy.NO_SHARD # MZY: set NO_SHARD when using DDP
checkpoint_type: str = "SHARDED_STATE_DICT" # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size.
fsdp_activation_checkpointing: bool = True
fsdp_cpu_offload: bool = False
pure_bf16: bool = False
optimizer: str = "AdamW"

@dataclass
class LogConfig:
use_wandb: bool = False
wandb_dir: str = "/root/test_wandb"
wandb_entity_name: str = "project_name"
wandb_project_name: str = "project_name"
wandb_exp_name: str = "exp_name"
log_file: str = "/root/test.log"
log_interval: int = 5
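
A quick, hypothetical example of instantiating these dataclasses directly; in the recipe they are composed and overridden through Hydra from the shell scripts, and the override values below are only examples.

```
# Hypothetical usage -- field names come from the dataclasses above; the
# override values are only examples.
from contextual_asr_config import ModelConfig, DataConfig, TrainConfig

model_cfg = ModelConfig(llm_name="vicuna-7b-v1.5", encoder_name="wavlm", encoder_dim=1024)
data_cfg = DataConfig(filter_type="char", probability_threshold=0.9, word_num=15)
train_cfg = TrainConfig(freeze_llm=True, freeze_encoder=True, use_peft=False)
print(model_cfg.llm_dim, data_cfg.infer_type, train_cfg.lr)
```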