Add MMS CTC Fine-Tuning #24281
Conversation
```diff
@@ -579,12 +632,24 @@ def remove_special_characters(batch):
         cache_dir=model_args.cache_dir,
         config=config,
         use_auth_token=data_args.use_auth_token,
+        ignore_mismatched_sizes=True,
```
This is needed when instantiating from CTC checkpoints
Good catch
```diff
@@ -31,6 +31,7 @@
 import numpy as np
 import torch
 from datasets import DatasetDict, load_dataset
+from safetensors.torch import save_file as safe_save_file
```
It's a required dependency so should be fine
I think this should go in its own example instead of adding some more code to the (already complex) ctc example. It's preferable to have multiple examples focused on one thing than one big multi-purpose example.
Ok for me
Thanks for adding! Changes all LGTM (filled in examples conditional ;))
+1 to @sgugger's suggestion of having a separate example
Looks good already, thanks for the updates @patrickvonplaten. Just some minor suggestions
```python
adapter_attn_dim: int = field(
    default=None,
    metadata={
        "help": "If defined, adapter layers will be randomely initialized and the rest of the model will be frozen."
```
Suggested change:
```diff
-        "help": "If defined, adapter layers will be randomely initialized and the rest of the model will be frozen."
+        "help": "If defined, adapter layers will be randomly initialized and the rest of the model will be frozen."
```
```python
adapter_language: Optional[str] = field(
    default=None,
    metadata={
        "help": (
```
Nice help message!
```diff
@@ -132,6 +134,12 @@ class ModelArguments:
     ctc_loss_reduction: Optional[str] = field(
         default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
     )
+    adapter_attn_dim: int = field(
+        default=None,
```
Can this default to some sensible value or should we always force the user to pass it?
Yes good point, now that things will be moved to a new file, I'll set a good default
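As a sketch of what a sensible default might look like (the value 16 is purely an assumption for illustration, not the default the PR actually ships):

```python
# Illustrative sketch of giving adapter_attn_dim a sensible default.
# The value 16 is an assumption for demonstration, not the PR's real default.
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ModelArguments:
    adapter_attn_dim: Optional[int] = field(
        default=16,
        metadata={
            "help": "If defined, adapter layers will be randomly initialized and the rest of the model will be frozen."
        },
    )

args = ModelArguments()
print(args.adapter_attn_dim)
```

Users who want full fine-tuning could still pass `--adapter_attn_dim` explicitly or set it to `None`.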
```python
# first we freeze the whole base model
model.freeze_base_model()

# next we unfreeze all adapter layers
```
Do we need to unfreeze the adapter weights? They don't get frozen in model.freeze_base_model()
They do get frozen in model.freeze_base_model() (adapter attention weights are part of it).
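To illustrate the pattern under discussion on a toy module: freezing the base model also freezes the adapter, so gradients must be re-enabled for the adapter parameters afterwards. The layer names below are stand-ins, not the real wav2vec2 structure:

```python
# Toy sketch of the freeze-then-unfreeze pattern. `base`, `adapter`, and
# `lm_head` are illustrative names, not the real MMS/wav2vec2 layout.
import torch.nn as nn

class ToyCTCModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.base = nn.Linear(8, 8)      # stands in for the wav2vec2 encoder
        self.adapter = nn.Linear(8, 8)   # stands in for adapter attention layers
        self.lm_head = nn.Linear(8, 4)   # CTC head, trained from scratch

    def freeze_base_model(self):
        # Freezes everything except the LM head -- including the adapter,
        # which is why the script must unfreeze the adapter again afterwards.
        for name, param in self.named_parameters():
            if not name.startswith("lm_head"):
                param.requires_grad = False

model = ToyCTCModel()
# first we freeze the whole base model
model.freeze_base_model()
# next we unfreeze only the adapter layers
for name, param in model.named_parameters():
    if "adapter" in name:
        param.requires_grad = True

trainable = sorted(n for n, p in model.named_parameters() if p.requires_grad)
print(trainable)
```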
```diff
@@ -1194,6 +1194,19 @@ def _get_adapters(self):

         return adapter_weights

+    def init_adapter_layers(self):
```
Nice!
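A rough sketch of what an `init_adapter_layers`-style method does: re-run weight initialization on just the adapter modules. The `is_adapter_layer` flag here is a stand-in for however the real model identifies its adapter attention modules:

```python
# Sketch: re-initialize only the adapter modules of a model in place.
# The `is_adapter_layer` attribute is a hypothetical marker; the real model
# walks its own adapter attention modules instead.
import torch
import torch.nn as nn

def init_adapter_layers(model: nn.Module):
    for module in model.modules():
        if getattr(module, "is_adapter_layer", False):
            module.reset_parameters()  # re-runs the layer's default init

adapter = nn.Linear(4, 4)
adapter.is_adapter_layer = True
model = nn.Sequential(adapter, nn.Linear(4, 4))

before = adapter.weight.clone()
init_adapter_layers(model)
# The adapter weights are freshly re-initialized; other layers are untouched.
print(bool((adapter.weight != before).any()))
```

This matches the training flow above: load a CTC checkpoint, randomly re-initialize the adapter, then fine-tune only those weights.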
Added a test. Moved the code into a new example file. Added an extensive README. WER for a quick 10-minute run can be as low as 23%!
Demo training run: https://huggingface.co/patrickvonplaten/wav2vec2-common_voice-tr-mms-demo
* Add mms ctc fine tuning
* make style
* More fixes that are needed
* make fix-copies
* make draft for README
* add new file
* move to new file
* make style
* make style
* add quick test
* make style
* make style
In which release will this be available?
You can find the example script here: https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition#connectionist-temporal-classification-with-adapters

It assumes you are running from the latest dev version (see run_speech_recognition_ctc_adapter.py, lines 55 to 56 at f104522), which you can do by following the instructions for installing from source or doing an editable install: https://huggingface.co/docs/transformers/installation#install-from-source

That said, for MMS ASR fine-tuning you can safely run the script with the latest PyPI release (4.31.0).
What does this PR do?
This PR adds language adapter fine-tuning for MMS. Still playing around with good hyper-parameters, but the script is functional.
Getting some very nice results already: WER drops to 25% after just 200 steps.
See: https://wandb.ai/patrickvonplaten/huggingface/runs/6f5cx5gg?workspace=user-patrickvonplaten
@sgugger @amyeroberts @sanchit-gandhi it'd be super nice to get a quick review here whether the code changes are generally fine with you. I'll only have to fill out the TODOs in the README with a nice example code and some description.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.