Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Consistency-Regularized CTC #1766

Merged
merged 6 commits into from
Oct 21, 2024
Merged

Add Consistency-Regularized CTC #1766

merged 6 commits into from
Oct 21, 2024

Conversation

yaozengwei
Copy link
Collaborator

@yaozengwei yaozengwei commented Oct 8, 2024

This PR implements the Consistency-Regularized CTC (CR-CTC) in https://arxiv.org/pdf/2410.05101,
which enforces consistency between two CTC distributions obtained from different augmented views of the input speech mel-spectrogram. It significantly improves the CTC performance, and could also be an auxiliary loss to boost the performance of transducer or CTC/AED. Please see paper for more details.

@yaozengwei
Copy link
Collaborator Author

On LibriSpeech dataset, results comparison with Zipformer, without using an external language model:

Model Params (M) test-clean test-other
CTC/AED, Zipformer-S 46.3 2.46 6.04
CTC/AED, Zipformer-M 90.0 2.22 4.97
CTC/AED, Zipformer-L 174.3 2.09 4.59
Pruned transducer, Zipformer-S 23.3 2.42 5.73
Pruned transducer, Zipformer-M 65.6 2.21 4.79
Pruned transducer, Zipformer-L 148.4 2.00 4.38
CTC, Zipformer-S 22.1 2.85 6.89
CTC, Zipformer-M 64.3 2.52 6.02
CTC, Zipformer-L 147.0 2.5 5.72
CR-CTC, Zipformer-S 22.1 2.52 5.85
CR-CTC, Zipformer-M 64.3 2.1 4.61
CR-CTC, Zipformer-L 147.0 2.02 4.35
CR-CTC/AED, Zipformer-L 174.3 1.96 4.08
Pruned transducer w/ CR-CTC, Zipformer-L 148.8 1.88 3.95

@csukuangfj
Copy link
Collaborator

Could you update RESULTS.md to include the URLs for the checkpoints and training logs of your PR?

@yaozengwei
Copy link
Collaborator Author

yaozengwei commented Oct 8, 2024

Could you update RESULTS.md to include the URLs for the checkpoints and training logs of your PR?

Sure. Will do it later.

@@ -950,7 +943,6 @@ def compute_loss(
spec_augment=spec_augment,
supervision_segments=supervision_segments,
time_warp_factor=params.spec_aug_time_warp_factor,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can not find the definition of spec_aug_time_warp_factor

Copy link
Collaborator Author

@yaozengwei yaozengwei Oct 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is defined in zipformer/asr_datamodule.py

@yaozengwei
Copy link
Collaborator Author

yaozengwei commented Oct 9, 2024

An example of training script using 4 * 32G-V100:

export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer/train.py \
  --world-size 4 \
  --num-epochs 50 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp-cr-loss-scale-0.2-time-mask-ratio-2.5 \
  --use-cr-ctc 1 \
  --use-ctc 1 \
  --use-transducer 0 \
  --use-attention-decoder 0 \
  --enable-spec-aug 0 \
  --cr-loss-scale 0.2 \
  --time-mask-ratio 2.5 \
  --full-libri 1 \
  --max-duration 700 \
  --master-port 12345

@yaozengwei
Copy link
Collaborator Author

I have uploaded the checkpoints and updated RESULTS.md. @pkufool will make a PR for adding ctc-prefix-decoding.

Copy link
Collaborator

@pkufool pkufool left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@pkufool pkufool merged commit 693d84a into k2-fsa:master Oct 21, 2024
75 of 108 checks passed
@yaozengwei
Copy link
Collaborator Author

I did some finetuning exps:

  • The initialized weights are from the models trained on GigaSpeech, using Transducer loss or CR-CTC loss
  • Finetune a Transducer model on LibriSpeech, only initialize the encoder (so the decoder and joiner are randomly initialized)

Results on GigaSpeech:

  • Zipformer-L, Transducer, 10.23, 10.28
  • Zipformer-L, CR-CTC, 10.31, 10.41

Finetuned results on LibriSpeech:

  • finetune on train-clean-100:
    Initialize with Transducer-trained encoder, epoch-5: 3.42, 7.45; epoch-10: 3.24, 7.36
    Initialize with CR-CTC-trained encoder, epoch-5: 3.12, 7.03; epoch-10: 3.18, 7.06
  • finetune on full-libri:
    Initialize with Transducer-trained encoder, epoch-5: 2.04, 4.57; epoch-10: 1.99, 4.39
    Initialize with CR-CTC-trained encoder, epoch-5: 1.99, 4.35; epoch-10: 1.97, 4.33

The results show that CR-CTC could be a good choice for pretraining.

@xiaoxi91
Copy link

First of all, I would like to express my deepest gratitude for sharing your invaluable code and paper. They have been immensely helpful in my research endeavors. While reading through your paper and exploring the code, I have encountered a question concerning the batch_size setting, and I would appreciate your insights.

In your paper, you mention that "As CR-CTC requires two forward pass during training, we train CR-CTC models with half the batch size and half the number of epochs compared to CTC models, ensuring a fair comparison in terms of training cost". However, in the model.py file, I noticed that the forward function scale the ctc_loss and transducer_loss by 0.5. I wonder do I need to continue adjusting the setting of batch_size(max_duration) ?

Once again, thank you for your hard work and generous sharing!
Best regards

@yaozengwei
Copy link
Collaborator Author

yaozengwei commented Oct 29, 2024

First of all, I would like to express my deepest gratitude for sharing your invaluable code and paper. They have been immensely helpful in my research endeavors. While reading through your paper and exploring the code, I have encountered a question concerning the batch_size setting, and I would appreciate your insights.

In your paper, you mention that "As CR-CTC requires two forward pass during training, we train CR-CTC models with half the batch size and half the number of epochs compared to CTC models, ensuring a fair comparison in terms of training cost". However, in the model.py file, I noticed that the forward function scale the ctc_loss and transducer_loss by 0.5. I wonder do I need to continue adjusting the setting of batch_size(max_duration) ?

Once again, thank you for your hard work and generous sharing! Best regards

For example, if you use max-duration of 1400 for standard CTC, you could use max-duration of 700 for CR-CTC. It will create two copies and then concat them along the batch dim. The reason why we scale the loss values by 0.5 is to keep the logging loss values comparable to other setups (without CR-CTC), as we get the info["frames"] in train.py (before batch duplicating) and normalize the loss values by that before printing. You could refer to the script examples in RESULTS.md.

@zhangwenkai-orion
Copy link

Are there any results in streaming ASR? My experiments on streaming ASR using CTC seem to not be working. The CTC loss gets worse while the CR loss gets better, WER gets worse.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants