self-supervised pretraining (wav2vec 2.0/data2vec) for wenet #1003

Draft · wants to merge 19 commits into main
65 changes: 65 additions & 0 deletions examples/aishell/s0_ssl/README.md
@@ -0,0 +1,65 @@
# w2v-conformer

This is an example of using an unsupervised pretrained w2v-conformer model to fine-tune on the AISHELL task.

We pretrain conformer encoders with the wav2vec 2.0 pretraining method, using fbank features as inputs.
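For readers unfamiliar with the objective, here is a minimal sketch of a wav2vec 2.0-style contrastive loss over masked frames. It is illustrative only: the function name, sampling scheme, and hyperparameters are assumptions, not this PR's code.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(context, targets, num_negatives=100, temperature=0.1):
    """context, targets: (num_masked, dim) tensors; `targets` holds the
    quantized/teacher representations at the masked positions."""
    n, _ = targets.shape
    # Distractors drawn from other masked positions. (Real implementations
    # also take care to exclude the true target from the negatives.)
    neg_idx = torch.randint(0, n, (n, num_negatives))
    candidates = torch.cat([targets.unsqueeze(1), targets[neg_idx]], dim=1)
    # Cosine similarity of each context vector against its candidates.
    logits = F.cosine_similarity(context.unsqueeze(1), candidates, dim=-1)
    # The true target always sits at candidate index 0.
    labels = torch.zeros(n, dtype=torch.long)
    return F.cross_entropy(logits / temperature, labels)
```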

The w2v-conformer model is pretrained on ISML, an internal dataset containing 60k hours of Chinese speech.


## Pretraining

We use two model configurations to pretrain the conformer encoder architecture:

The Base model contains 12 conformer blocks with model dimension 512, FFN dimension 2048, and 8 attention heads. Samples are batched together so as not to exceed 30,000 frames per GPU. We train on 32 V100 GPUs for 800k steps.

The Middle model contains 24 conformer blocks with model dimension 512, FFN dimension 2048, and 8 attention heads. We add a reconstruction loss to slightly improve performance. To speed up training, we change the time stride of the convolutional subsampling blocks to 3, so the input feature sequence shrinks to one sixth of its original length. Samples are batched together so as not to exceed 20,000 frames per GPU. We train on 32 V100 GPUs for 600k steps.
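The shape of that subsampling change can be seen in a small sketch. Standard conv2d subsampling uses two stride-2 convolutions (4x time reduction); replacing the second stride with 3 gives 2 x 3 = 6x, which is what the `conv2d6` input-layer option mentioned in the configs below corresponds to. This module is a simplified illustration, not the PR's actual code:

```python
import torch
import torch.nn as nn

class Conv2dSubsampling6(nn.Module):
    """Subsample time by 6x: a stride-2 then a stride-3 convolution."""
    def __init__(self, idim: int, odim: int):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, odim, 3, 2),     # time (and freq) / 2
            nn.ReLU(),
            nn.Conv2d(odim, odim, 5, 3),  # time (and freq) / 3
            nn.ReLU(),
        )
        # Project the flattened channel x freq axes back to the model dim.
        self.out = nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, idim) fbank features
        x = self.conv(x.unsqueeze(1))      # (batch, odim, time/6, freq')
        b, c, t, f = x.size()
        return self.out(x.transpose(1, 2).reshape(b, t, c * f))

# e.g. Conv2dSubsampling6(80, 512)(torch.randn(4, 1000, 80)) -> (4, 165, 512)
```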

We are also working on a causal model for U2 training and a large model with 300M parameters; this work is ongoing.

Pretrained model links:

| Model  | Parameters | Link |
|--------|------------|------|
| Base   | 90M        | -    |
| Middle | 180M       | -    |



## Finetuning tips

* After pretraining, we can build an encoder-decoder based ASR system: the conformer-based encoder takes the pretrained model as its initialization, while the transformer-based decoder is trained from scratch. Just set `--enc_init_mods` to something like 'encoder.embed.,encoder.encoders.0.,encoder.encoders.1. ...' to load the selected pretrained parameters (see the first sketch after these tips).

* On the AISHELL task, we carefully tune the learning rate to 0.0004~0.0005 to get the best performance. We also find that giving the decoder too many layers degrades the transfer performance of the pretrained model, so we only build a small transformer decoder for joint training. If the downstream task has more than 500 hours of data, you can increase the learning rate and the decoder's parameter count.

* Please note that the final layer of the pretrained model does not provide a good initialization for fine-tuning and benefits from being re-initialized before fine-tuning (see the second sketch below).
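As referenced in the first tip, loading only selected pretrained modules amounts to filtering the checkpoint's state dict by name prefix. The following is a minimal sketch of that idea; the function name, checkpoint path, and prefixes are made up for illustration, and wenet's actual logic lives in its checkpoint utilities:

```python
import torch

def load_pretrained_modules(model, ckpt_path, prefixes):
    """Initialize only the modules whose names start with one of `prefixes`."""
    state = torch.load(ckpt_path, map_location='cpu')
    filtered = {k: v for k, v in state.items()
                if any(k.startswith(p) for p in prefixes)}
    # strict=False leaves everything else (e.g. the decoder) randomly initialized.
    model.load_state_dict(filtered, strict=False)
    return sorted(filtered)

# e.g. load_pretrained_modules(model, 'w2v_conformer.pt',
#                              ['encoder.embed.', 'encoder.encoders.'])
```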
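And for the last tip, re-initializing the final encoder block before fine-tuning can look like the sketch below. It assumes wenet's `ConformerEncoder` layout (`model.encoder.encoders` as a list of blocks); adapt the module path to your model:

```python
def reinit_final_block(model):
    """Reset the last conformer block to its default random initialization."""
    for module in model.encoder.encoders[-1].modules():
        if hasattr(module, 'reset_parameters'):
            module.reset_parameters()  # each layer's own default init
```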

# Base model performance

## Conformer Result

* config: conf/train_base_conformer_100h.yaml
* Training info: lr 0.0004, batch size 16, 4 A100 GPUs, acc_grad 1, 250 epochs
* Decoding info: ctc_weight 0.5, average_num 35

| decoding mode | CER |
|---------------------------|-------|
| ctc greedy search | 3.86 |
| ctc prefix beam search | 3.86 |
| attention rescoring | 3.79 |

# Middle model performance

## Conformer Result

* config: conf/train_conformer_large_100h.yaml
* Training info: lr 0.0005, batch size 16, 4 A100 GPUs, acc_grad 1, 250 epochs
* Decoding info: ctc_weight 0.5, average_num 35

| decoding mode | CER |
|---------------------------|-------|
| ctc greedy search | 3.46 |
| ctc prefix beam search | 3.46 |
| attention rescoring | 3.37 |


77 changes: 77 additions & 0 deletions examples/aishell/s0_ssl/conf/train_base_conformer_100h.yaml
@@ -0,0 +1,77 @@
# network architecture
# encoder related
encoder: conformer
encoder_conf:
    output_size: 512    # dimension of attention
    attention_heads: 8
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: true
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rel_pos'
    selfattention_layer_type: 'rel_selfattn'

# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 512
    num_blocks: 2
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.7
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false

dataset_conf:
    filter_conf:
        max_length: 40960
        min_length: 0
        token_max_length: 200
        token_min_length: 1
    resample_conf:
        resample_rate: 16000
    speed_perturb: true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 0.1
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 50
        max_f: 10
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    sort: true
    sort_conf:
        sort_size: 500  # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'static' # static or dynamic
        batch_size: 16

grad_clip: 5
accum_grad: 2
max_epoch: 240
log_interval: 100

optim: adam
optim_conf:
    lr: 0.0004
scheduler: warmuplr     # pytorch v1.1.0+ required
scheduler_conf:
    warmup_steps: 25000
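A note on the `warmuplr` entries in these configs: the scheduler follows the standard Transformer warm-up shape, rising linearly for `warmup_steps` and then decaying with the inverse square root of the step count. The formula below matches the common ESPnet/wenet implementation and is stated here as an assumption for reference:

```python
def warmup_lr(base_lr: float, step: int, warmup_steps: int = 25000) -> float:
    """Learning rate at `step`; peaks at base_lr when step == warmup_steps."""
    step = max(step, 1)  # guard against step 0
    return base_lr * warmup_steps ** 0.5 * min(step ** -0.5,
                                               step * warmup_steps ** -1.5)
```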
77 changes: 77 additions & 0 deletions examples/aishell/s0_ssl/conf/train_conformer.yaml
@@ -0,0 +1,77 @@
# network architecture
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: true
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rel_pos'
    selfattention_layer_type: 'rel_selfattn'

# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false

dataset_conf:
    filter_conf:
        max_length: 40960
        min_length: 0
        token_max_length: 200
        token_min_length: 1
    resample_conf:
        resample_rate: 16000
    speed_perturb: true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 0.1
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 50
        max_f: 10
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    sort: true
    sort_conf:
        sort_size: 500  # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'static' # static or dynamic
        batch_size: 16

grad_clip: 5
accum_grad: 4
max_epoch: 240
log_interval: 100

optim: adam
optim_conf:
    lr: 0.002
scheduler: warmuplr     # pytorch v1.1.0+ required
scheduler_conf:
    warmup_steps: 25000
83 changes: 83 additions & 0 deletions examples/aishell/s0_ssl/conf/train_u2++_conformer.yaml
@@ -0,0 +1,83 @@
# network architecture
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: true
    cnn_module_kernel: 8
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rel_pos'
    selfattention_layer_type: 'rel_selfattn'
    causal: true
    use_dynamic_chunk: true
    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
    use_dynamic_left_chunk: false

# decoder related
decoder: bitransformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 3
    r_num_blocks: 3
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
    reverse_weight: 0.3

dataset_conf:
    filter_conf:
        max_length: 40960
        min_length: 0
        token_max_length: 200
        token_min_length: 1
    resample_conf:
        resample_rate: 16000
    speed_perturb: true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 1.0
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 50
        max_f: 10
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    sort: true
    sort_conf:
        sort_size: 500  # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'static' # static or dynamic
        batch_size: 16

grad_clip: 5
accum_grad: 1
max_epoch: 360
log_interval: 100

optim: adam
optim_conf:
    lr: 0.001
scheduler: warmuplr     # pytorch v1.1.0+ required
scheduler_conf:
    warmup_steps: 25000
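The `causal: true` / `use_dynamic_chunk: true` pair in this config trains the encoder with chunk-based attention masks so the same model can stream at inference time. Below is a simplified sketch of such a mask; wenet's real utility also randomizes the chunk size per batch and can limit left context, whereas this version lets each frame attend to its own chunk plus all previous ones:

```python
import torch

def chunk_attention_mask(size: int, chunk_size: int) -> torch.Tensor:
    """Boolean (size, size) mask: frame i may attend to frame j iff
    j lies in i's chunk or any earlier chunk."""
    idx = torch.arange(size)
    chunk_end = (idx // chunk_size + 1) * chunk_size  # first frame i cannot see
    return idx.unsqueeze(0) < chunk_end.unsqueeze(1)

# e.g. chunk_attention_mask(6, 2) lets frame 3 attend to frames 0..3
```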
81 changes: 81 additions & 0 deletions examples/aishell/s0_ssl/conf/train_unified_conformer.yaml
@@ -0,0 +1,81 @@
# network architecture
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: true
cnn_module_kernel: 15
use_cnn_module: True
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
selfattention_layer_type: 'rel_selfattn'
causal: true
use_dynamic_chunk: true
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk: false

# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0

# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

dataset_conf:
filter_conf:
max_length: 40960
min_length: 0
token_max_length: 200
token_min_length: 1
resample_conf:
resample_rate: 16000
speed_perturb: true
fbank_conf:
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 0.1
spec_aug: true
spec_aug_conf:
num_t_mask: 2
num_f_mask: 2
max_t: 50
max_f: 10
shuffle: true
shuffle_conf:
shuffle_size: 1500
sort: true
sort_conf:
sort_size: 500 # sort_size should be less than shuffle_size
batch_conf:
batch_type: 'static' # static or dynamic
batch_size: 16

grad_clip: 5
accum_grad: 1
max_epoch: 180
log_interval: 100

optim: adam
optim_conf:
lr: 0.001
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000