
Add MMS CTC Fine-Tuning #24281

Merged
merged 15 commits into from
Jun 14, 2023

Conversation

patrickvonplaten
Contributor

@patrickvonplaten commented Jun 14, 2023

What does this PR do?

This PR adds language adapter fine-tuning for MMS. I'm still experimenting to find good hyper-parameters, but the script is already functional.

Getting some very nice results now with:

export CUDA_VISIBLE_DEVICES="0"
LEARNING_RATE="1e-3"

python run_speech_recognition_ctc.py \
        --dataset_name="common_voice" \
        --model_name_or_path="facebook/mms-1b-all" \
        --dataset_config_name="tr" \
        --output_dir="./wav2vec2-common_voice-tr-mms-demo" \
        --overwrite_output_dir \
        --num_train_epochs="15" \
        --per_device_train_batch_size="32" \
        --learning_rate="${LEARNING_RATE}" \
        --warmup_steps="400" \
        --evaluation_strategy="steps" \
        --text_column_name="sentence" \
        --length_column_name="input_length" \
        --save_steps="400" \
        --eval_steps="200" \
        --layerdrop="0.0" \
        --save_total_limit="3" \
        --adapter_attn_dim="16" \
        --adapter_language="tur" \
        --gradient_checkpointing \
        --chars_to_ignore , ? . ! - \; \: \" “ % ‘ ” � \
        --fp16 \
        --group_by_length \
        --do_train --do_eval

WER drops to 25% just after 200 steps.
See: https://wandb.ai/patrickvonplaten/huggingface/runs/6f5cx5gg?workspace=user-patrickvonplaten
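For reference, word error rate is just the word-level edit distance between reference and hypothesis, divided by the number of reference words. A minimal dependency-free sketch of the metric (illustrative only; the training script computes WER via a metric library, not this helper):

```python
# Minimal WER sketch: word-level Levenshtein distance over reference words.
# Illustrative only -- the training script uses a metric library for this.
def wer(reference: str, hypothesis: str) -> float:
    ref, hyp = reference.split(), hypothesis.split()
    # dp[i][j] = edit distance between ref[:i] and hyp[:j]
    dp = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dp[i][0] = i
    for j in range(len(hyp) + 1):
        dp[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,         # deletion
                           dp[i][j - 1] + 1,         # insertion
                           dp[i - 1][j - 1] + cost)  # substitution
    return dp[-1][-1] / len(ref)

# One substitution + one deletion over 5 reference words -> 0.4
print(wer("merhaba nasılsın bugün hava güzel", "merhaba nasilsin bugün hava"))  # 0.4
```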

@sgugger @amyeroberts @sanchit-gandhi it'd be super nice to get a quick review here whether the code changes are generally fine with you. I'll only have to fill out the TODOs in the README with a nice example code and some description.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@patrickvonplaten patrickvonplaten changed the title Add mms ctc fine tuning [WIP] Add mms ctc fine tuning Jun 14, 2023
@patrickvonplaten patrickvonplaten marked this pull request as draft June 14, 2023 16:00
@HuggingFaceDocBuilderDev commented Jun 14, 2023

The documentation is not available anymore as the PR was closed or merged.

@@ -579,12 +632,24 @@ def remove_special_characters(batch):
cache_dir=model_args.cache_dir,
config=config,
use_auth_token=data_args.use_auth_token,
ignore_mismatched_sizes=True,
Contributor Author

This is needed when instantiating from CTC checkpoints

Contributor

Good catch

@@ -31,6 +31,7 @@
import numpy as np
import torch
from datasets import DatasetDict, load_dataset
from safetensors.torch import save_file as safe_save_file
Contributor Author

It's a required dependency so should be fine
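For context, `safetensors` is imported here so that just the adapter weights can be saved to their own file. The selection itself is simple name-based filtering over the state dict; a sketch of the pattern (the key substring `adapter_layer` is a hypothetical naming convention for illustration, not necessarily the exact name the script matches):

```python
# Sketch: pick only adapter parameters out of a full state dict so they can
# be saved separately (e.g. with safetensors' save_file). The substring
# "adapter_layer" is a hypothetical naming convention for illustration.
def get_adapter_weights(state_dict, marker="adapter_layer"):
    return {name: tensor for name, tensor in state_dict.items() if marker in name}

# Toy state dict with plain values standing in for tensors:
full = {
    "encoder.layers.0.attention.q_proj.weight": "big",
    "encoder.layers.0.adapter_layer.weight": "small",
    "lm_head.weight": "head",
}
adapters = get_adapter_weights(full)
print(sorted(adapters))  # ['encoder.layers.0.adapter_layer.weight']
```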

@patrickvonplaten patrickvonplaten changed the title [WIP] Add mms ctc fine tuning Add mms ctc fine tuning Jun 14, 2023
@patrickvonplaten patrickvonplaten marked this pull request as ready for review June 14, 2023 17:20
@patrickvonplaten patrickvonplaten changed the title Add mms ctc fine tuning Add MMS CTC Fine-Tuning Jun 14, 2023
Collaborator

@sgugger left a comment

I think this should go in its own example instead of adding some more code to the (already complex) ctc example. It's preferable to have multiple examples focused on one thing than one big multi-purpose example.

@patrickvonplaten
Contributor Author

I think this should go in its own example instead of adding some more code to the (already complex) ctc example. It's preferable to have multiple examples focused on one thing than one big multi-purpose example.

Ok for me

Collaborator

@amyeroberts left a comment

Thanks for adding! Changes all LGTM (filled in examples conditional ;))

+1 to @sgugger's suggestion of having a separate example

Contributor

@sanchit-gandhi left a comment

Looks good already, thanks for the updates @patrickvonplaten. Just some minor suggestions

adapter_attn_dim: int = field(
default=None,
metadata={
"help": "If defined, adapter layers will be randomely initialized and the rest of the model will be frozen."
Contributor

Suggested change
"help": "If defined, adapter layers will be randomely initialized and the rest of the model will be frozen."
"help": "If defined, adapter layers will be randomly initialized and the rest of the model will be frozen."

adapter_language: Optional[str] = field(
default=None,
metadata={
"help": (
Contributor

Nice help message!

@@ -132,6 +134,12 @@ class ModelArguments:
ctc_loss_reduction: Optional[str] = field(
default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
)
adapter_attn_dim: int = field(
default=None,
Contributor

Can this default to some sensible value or should we always force the user to pass it?

Contributor Author

Yes, good point. Now that things are being moved to a new file, I'll set a sensible default.
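As a sketch of what such a default could look like (the value 16 simply mirrors the `--adapter_attn_dim="16"` used in the example command above; the actual default chosen in the merged PR may differ):

```python
from dataclasses import dataclass, field

@dataclass
class AdapterArguments:
    # Hypothetical default mirroring the example command above; not
    # necessarily the value the merged script settled on.
    adapter_attn_dim: int = field(
        default=16,
        metadata={
            "help": "Dimension of the adapter attention layers; "
                    "the rest of the model is frozen during fine-tuning."
        },
    )

args = AdapterArguments()
print(args.adapter_attn_dim)  # 16
```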

# first we freeze the whole base model
model.freeze_base_model()

# next we unfreeze all adapter layers
Contributor

Do we need to unfreeze the adapter weights? They don't get frozen in model.freeze_base_model()

Contributor Author

They do get frozen in model.freeze_base_model() (adapter attention weights are part of it)
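The point can be illustrated with a toy module: if `freeze_base_model()` disables gradients for everything under the base model (adapters included), the adapter parameters must be re-enabled by name afterwards. A sketch under that assumption (the module and attribute names here are made up, not the real Wav2Vec2 ones):

```python
# Toy illustration of the freeze/unfreeze pattern discussed above. Module
# names (TinyBase, adapter_layer, ...) are invented for this sketch and are
# not the real Wav2Vec2 attribute names.
import torch.nn as nn

class TinyBase(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.adapter_layer = nn.Linear(4, 4)  # stands in for an attention adapter

class TinyCTC(nn.Module):
    def __init__(self):
        super().__init__()
        self.base = TinyBase()
        self.lm_head = nn.Linear(4, 2)

    def freeze_base_model(self):
        # Freezes everything under the base model -- adapters included,
        # which is the behaviour confirmed in the thread above.
        for p in self.base.parameters():
            p.requires_grad = False

model = TinyCTC()

# first we freeze the whole base model
model.freeze_base_model()
assert not any(p.requires_grad for p in model.base.adapter_layer.parameters())

# next we unfreeze only the adapter layers, by name
for name, p in model.named_parameters():
    if "adapter_layer" in name:
        p.requires_grad = True

trainable = sorted(n for n, p in model.named_parameters() if p.requires_grad)
print(trainable)
# ['base.adapter_layer.bias', 'base.adapter_layer.weight', 'lm_head.bias', 'lm_head.weight']
```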

@@ -579,12 +632,24 @@ def remove_special_characters(batch):
cache_dir=model_args.cache_dir,
config=config,
use_auth_token=data_args.use_auth_token,
ignore_mismatched_sizes=True,
Contributor

Good catch

@@ -1194,6 +1194,19 @@ def _get_adapters(self):

return adapter_weights

def init_adapter_layers(self):
Contributor

Nice!

@patrickvonplaten
Contributor Author

Added a test, moved the code into a new example file, and added an extensive README. WER for a quick 10-minute run can be as low as 23%!

@patrickvonplaten patrickvonplaten merged commit 1609a43 into main Jun 14, 2023
@patrickvonplaten patrickvonplaten deleted the add_mms_ctc_fine_tuning branch June 14, 2023 23:10

novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
* Add mms ctc fine tuning

* make style

* More fixes that are needed

* make fix-copies

* make draft for README

* add new file

* move to new file

* make style

* make style

* add quick test

* make style

* make style
@dash8x commented Jul 17, 2023

In which release will this be available?

@sanchit-gandhi
Contributor

You can find the example scripts here: https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition#connectionist-temporal-classification-with-adapters

They assume that you are running from the latest dev version:

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.32.0.dev0")

Which you can do by following the instructions for installing from source or editable install here: https://huggingface.co/docs/transformers/installation#install-from-source

Although for MMS ASR fine-tuning, you can safely run the script using the latest PyPI release version (4.31.0).
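The `check_min_version` guard quoted above is essentially a parsed-tuple comparison; a rough dependency-free sketch of the idea (the real implementation in `transformers.utils` is more thorough about pre-release suffixes):

```python
import re

def parse_version(v: str):
    # "4.32.0.dev0" -> (4, 32, 0); dev/rc suffixes are ignored in this sketch
    return tuple(int(x) for x in re.findall(r"\d+", v)[:3])

def check_min_version(installed: str, minimum: str) -> None:
    # Will error if the installed version is older than the minimum.
    if parse_version(installed) < parse_version(minimum):
        raise RuntimeError(
            f"Transformers {minimum}+ required, found {installed}. "
            "Install from source or upgrade."
        )

check_min_version("4.32.0.dev0", "4.31.0")  # passes silently
```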
