Jax/Flax pretraining of wav2vec2 #19588
Hey @aapot! Cool to see you're trying out pre-training of Wav2Vec2 in Flax on Finnish 🇫🇮 Indeed, the script is under 'research projects' as it remains unsolved. Pre-training of Wav2Vec2 models is notoriously difficult due to issues with stability, giving rise to the phenomena you've experienced, such as code vector collapse and unstable contrastive loss. AFAIK there isn't a working implementation for training Transformers W2V2 models in Flax, which makes it an interesting topic to pursue! You've done a great job of digging through issues and PRs to find the aforementioned points! Both points you've raised look to be missing from the Flax Wav2Vec2 script. Did you try gradient scaling in your experiments? One thing we can try is running the PyTorch and Flax scripts step-by-step in parallel and inspecting where they diverge. We can do this with a tiny dummy model.
Just for reference, I never got Wav2Vec2 to work in JAX, but it should definitely be possible (didn't spend too much time on it).
@sanchit-gandhi yep, this would be interesting to get working! Yes, I also tried gradient scaling like it was implemented in the PyTorch pretrain script (basically multiplying gradients by (num devices / total samples)) without luck. I'd be interested in putting some time into fixing this, so feel free to provide further pointers. Training these ASR and NLP models for Finnish is a free-time hobby project with @R4ZZ3, so I cannot promise anything yet, but let's get this fixed 🤗
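For reference, the gradient-scaling trick described above (multiplying gradients by num devices / total samples) can be sketched roughly like this. This is a minimal illustration, not the actual script's code: the gradients are plain nested dicts standing in for a Flax parameter pytree, and in real JAX code you would use `jax.tree_util.tree_map` instead of the hand-rolled recursion.

```python
# Sketch of the gradient-scaling idea: the per-device summed loss is
# normalised by scaling every gradient leaf with num_devices / total_samples.
# Nested dicts stand in for a JAX/Flax parameter pytree here.

def scale_grads(grads, num_devices, total_samples):
    """Multiply every leaf of a (nested-dict) gradient tree by
    num_devices / total_samples."""
    scale = num_devices / total_samples
    if isinstance(grads, dict):
        return {k: scale_grads(v, num_devices, total_samples) for k, v in grads.items()}
    return grads * scale

grads = {"encoder": {"kernel": 8.0, "bias": 4.0}}
scaled = scale_grads(grads, num_devices=8, total_samples=64)
print(scaled)  # {'encoder': {'kernel': 1.0, 'bias': 0.5}}
```

With a real Flax gradient pytree, the same one-liner would be `jax.tree_util.tree_map(lambda g: g * num_devices / total_samples, grads)`.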
Awesome @aapot, that's great to hear! Essentially what you want to do is run the PyTorch script and the Flax script with identical args (for the model, data and training args). In doing this, the PyTorch and Flax models should receive identical inputs, and thus should compute identical losses if the training scripts are the same. What you then want to do is compare the outputs of the PT and Flax training scripts after each step of pre-training:
It's likely the bug in the Flax script lies in 3, 4 or 5! Once you identify where the losses deviate, you can dig deeper into the code for PT and Flax and try to find the line(s) of code where the functionality differs. How you debug this is up to you. To make this quick and easy, I'd recommend using a tiny dummy model. For comparing the outputs, you can either run the scripts side-by-side and print intermediate values, or combine them into a single notebook and print cell outputs after each step (I can give you a template for this if you want to use a notebook). Print statements are easy to use, but don't provide much detail other than numeric values. What I'd do is first add print statements for each of the items listed in 1-5 to quickly see which values match ✅ and which don't ❌. After that you can go deeper with either more print statements, breakpoints (ipdb), or a debugger. It might take a bit of time to establish a good set-up for debugging quickly, but once you've got it, it should be a case of finding where the losses differ and then fixing the Flax side! You might also need to disable shuffling of the training dataset to make sure the training inputs are passed to PT and Flax in the same way. These should make for good starting points (haven't tried them, but they're similar to the configs I use for debugging ASR fine-tuning):

PT

Flax
Thanks for those pointers @sanchit-gandhi, sounds reasonable! I'll start digging into this soon and will keep you updated here.
Hi,
I tried pre-training the JAX wav2vec2 model on my own data and came across similar problems. I tried with multiple huge chunks of my own dataset and the perplexity always collapsed to 2 while the loss fluctuated a lot. I also noticed that across all my datasets the eval loss was always 0.09969. So, if I fine-tune this pretrained model, will it give any good results? Also, do you have any code to fine-tune this pretrained model that I can use?
For fine-tuning we have used these resources as a base. We are also going to try out these; we just need to fix some of our datasets first, as we have lower-case material. Luckily we have trained a T5 model for casing + punctuation correction. https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition#sequence-to-sequence
Hey @Aaryan369! Thanks for sharing your experience - it seems there's an inherent bug in the JAX pre-training implementation in how the loss terms are computed, leading to code vector perplexity collapse and unstable loss. You can certainly try fine-tuning a pre-trained Wav2Vec2 model. If your fine-tuning data is in-domain with the pre-training data, you can expect good results with very little data - as little as 10 minutes, as shown in the Wav2Vec2 paper! If your fine-tuning data is more out-of-domain with the pre-training data, you can expect to need much more data to achieve good results. This is really a case-by-case decision, so you'll have to make it based on what you know about your fine-tuning situation! In terms of pre-trained models, there are English-only checkpoints, and multilingual ones (https://huggingface.co/facebook/wav2vec2-large-xlsr-53 for example). The English-only ones will fare better for English speech tasks, and the multilingual ones for most others. The resources @R4ZZ3 has kindly linked are perfect for fine-tuning in PyTorch. If you want to fine-tune in JAX, I'd advise you to try: https://github.com/sanchit-gandhi/seq2seq-speech/blob/main/run_flax_speech_recognition_ctc.py This script closely resembles the PyTorch one in Transformers: https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition It's on my list to add this JAX CTC fine-tuning script to Transformers over the coming weeks!
We are compiling a large speech corpus in Norwegian (100k+ hours) and expect it to be ready in roughly a month. Our plan is to pretrain a Wav2Vec2 model. We have access to TPUs through TRC, and ideally we would like to train this in Flax instead of XLA/PT. This is a high-priority project for us, and I am happy to assist with both testing and debugging here.
100k is mega! Very excited to see how pre-training JAX Wav2Vec2 in Finnish goes with this much data. Just out of interest, are you set on producing a pre-trained checkpoint in Finnish? Or is the end goal downstream ASR? The multilingual Whisper models are pre-trained on 1066h of labelled Finnish audio-transcription data (out of 670,000h total). They get good results with zero-shot transfer learning (i.e. no fine-tuning) on Finnish Common Voice 9 (17.0% WER) and Finnish VoxPopuli (15.5% WER), cf. Tables 11 and 12 of the Whisper paper. You could definitely improve upon these results with fine-tuning! Might be a faster route to a performant downstream Finnish ASR model than Wav2Vec2 pre-training + fine-tuning?
Just a quick update: I finally had time to start the actual debugging. More info to follow soon.
Great! Keep us posted!
Alright, here are some findings so far. I have found a couple of differences between the Flax and PT model code:

```python
y_soft = nn.softmax((hidden_states + gumbels) / temperature)
index = y_soft.argmax(axis=-1)
y_hard = jnp.zeros_like(hidden_states).at[jnp.arange(len(hidden_states)), index].set(1.0)
codevector_probs = y_hard - y_soft + y_soft
```

when the PT code looks like this:

```python
y_soft = gumbels.softmax(dim)
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
```

At first, I also had the PT's

In addition, I have made initial updates (some more smaller updates could still be made) to the
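Worth noting on the comparison above: the JAX analogue of PyTorch's `.detach()` is `jax.lax.stop_gradient`, so the Flax line without it reduces to a plain hard one-hot with no straight-through gradient. A hedged sketch of what a matching straight-through estimator could look like in JAX (my own illustration, not the actual Transformers code):

```python
import jax
import jax.numpy as jnp

def straight_through(y_soft):
    """Hard one-hot in the forward pass, soft gradients in the backward
    pass, mirroring PT's `y_hard - y_soft.detach() + y_soft`."""
    index = y_soft.argmax(axis=-1)
    y_hard = jax.nn.one_hot(index, y_soft.shape[-1])
    return y_hard - jax.lax.stop_gradient(y_soft) + y_soft

y_soft = jax.nn.softmax(jnp.array([1.0, 2.0, 0.5]))
out = straight_through(y_soft)
# Forward value equals the hard one-hot:
print(out)  # [0. 1. 0.]
# ...but gradients flow through y_soft (identity Jacobian of the soft term):
grad = jax.grad(lambda y: straight_through(y).sum())(y_soft)
```

Without the `stop_gradient`, `y_hard - y_soft + y_soft` has the same forward value but the `-y_soft + y_soft` gradient contributions cancel differently than in PT, which could plausibly explain a training divergence.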
Continuing with the updates:
Really great work, @aapot. I do however understand that there are still some issues here (since the contrastive loss starts to increase after a while), and that the issue is most likely related to the Flax gumbel implementation. Any chance anyone at 🤗 can take a look at that? What do you think @sanchit-gandhi @patrickvonplaten? When this is done I'll be glad to contribute with larger training, and fine-tuning/testing on downstream tasks.
You could hack into the code and load pre-trained weights! I'd recommend the checkpoint at https://huggingface.co/hf-internal-testing/tiny-random-wav2vec2

PyTorch:

```python
from transformers import Wav2Vec2ForPreTraining

model = Wav2Vec2ForPreTraining.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
```

JAX:

```python
from transformers import FlaxWav2Vec2ForPreTraining

model = FlaxWav2Vec2ForPreTraining.from_pretrained("hf-internal-testing/tiny-random-wav2vec2", from_pt=True)
```

=> this will initialise the models with the same weights! From the PyTorch code, it seems as though we should break

Sounds like you're making good progress @aapot! Keep us posted with updates and questions, happy to help!
Oh, one more question! You're running both on CPU, right? JAX will definitely diverge from PT on GPU/TPU due to differences in matmul precision (cf. jax-ml/jax#10413 (comment)).
Thanks for the tips @sanchit-gandhi! Actually I also had in mind to use pre-trained weights to compare model outputs that way; will try it soon. I'll also check the fairseq implementation to see if it reveals more things to fix. Yup, I am running both JAX and PT on my local laptop on CPU when debugging.
Okay great! Tiny pre-trained models on CPU is the way to go here!
After using the pre-trained weights to continue pretraining for one more step with the same input data, I think the following is happening with the model outputs:
In addition, I checked the fairseq gumbel softmax implementation and they are also using the PyTorch's

If someone could verify if using
Fantastic work @aapot! I noticed the following comment in the pull from @patrickvonplaten to @ThomAub: Might there be issues here?
This is exactly the way we want to compute differences between PT and Flax in the projected states space 👌 For reference, a matching implementation should have a max abs difference of 1e-5.
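A tiny helper along these lines (pure NumPy, the names are my own) makes the side-by-side comparison quick. Per the 1e-5 threshold mentioned above, it flags each compared tensor pair:

```python
import numpy as np

def compare(name, pt_value, fx_value, atol=1e-5):
    """Report the max absolute difference between a PT and a Flax tensor
    (both already converted to NumPy arrays)."""
    diff = np.max(np.abs(np.asarray(pt_value) - np.asarray(fx_value)))
    status = "match" if diff < atol else "MISMATCH"
    print(f"[{status}] {name}: max abs diff = {diff:.2e}")
    return diff

# hypothetical usage: compare("projected_states", pt_out.detach().numpy(), np.asarray(fx_out))
compare("dummy", np.array([1.0, 2.0]), np.array([1.0, 2.0 + 1e-7]))
```

Sprinkling calls like this after each intermediate value (projected states, quantised states, each loss term) narrows down where the implementations first deviate.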
This seems logical! What I would do is dive into the GumbelVectorQuantizer and check the intermediate variables up to where the randomised sampling is performed. If they match up until the sampling, that's a good sign. Forcing the sampling to be the same between PT and Flax is a bit tricky... IMO we have two options:
Unfortunately, both of these methods are a bit hacky. The first might be easier IMO - you don't have to define it to be anything too crazy; just 2 or 3 different matrices would do (we only need to verify the outputs are the same over 2 or 3 training steps).
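The predefined-matrix option can be as simple as sampling the Gumbel noise once in NumPy and injecting the same matrices into both frameworks, replacing their internal sampling while debugging. A hedged sketch (shapes and names are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(seed=0)

def make_gumbel_noise(shape):
    """Sample standard Gumbel noise G = -log(-log(U)) once in NumPy so the
    exact same matrix can be fed to both the PT and Flax quantizers."""
    u = rng.uniform(low=1e-9, high=1.0, size=shape)  # avoid log(0)
    return -np.log(-np.log(u))

# e.g. 2 or 3 precomputed matrices, one per debug training step
noise_per_step = [make_gumbel_noise((4, 16)) for _ in range(3)]
```

Each precomputed matrix would then be converted with `torch.from_numpy(...)` on the PT side and `jnp.asarray(...)` on the Flax side, so both quantizers see byte-identical noise at each step.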
Sounds like we're narrowing down! Maybe we can try forcing the same Gumbel quantiser outputs and then experiment with / without stop gradient. The fact that fairseq and HF PT use
Good point @peregilk - worth having a look to see if there are any OSS implementations of the Gumbel ops in JAX/Flax online! (As far as I'm aware there aren't, but I might be wrong!)
Please keep this issue open. There is still activity going on to solve this issue.
Hope the analysis is going ok @aapot, think you're doing a great job here! Feel free to share any updates / ask questions, more than happy to help!
Hi @sanchit-gandhi, unfortunately I have been very busy the past month, so I haven't had time to investigate this JAX gumbel quantizer further. Now that the recent Hugging Face Whisper fine-tuning event is over (I participated too), I'll get back to debugging this wav2vec2 pretraining after a short Christmas break :) In any case, I am planning to create a PR of my current work even if the Gumbel quantizer doesn't get fixed, because my current branch brings the Wav2Vec2 Flax model and pretraining code implementation pretty much up to date with the PyTorch version. But I hope we get the Gumbel part fixed too.
Hey @aapot! Hope you had a nice Xmas break and that you enjoyed the Whisper event 🙂 Thanks for the update! Sounds good regarding opening a PR with the current changes - these are certainly welcome fixes! We can iterate on the PR to see if we can get the Gumbel part fixed too. Feel free to ping me here or on the new PR with questions - more than happy to help and excited to see this one through to completion!
@sanchit-gandhi quick update on the GumbelVectorQuantizer with option 1 you mentioned earlier (replacing the Gumbel-sampled matrix with a predefined matrix). Any ideas how to proceed?
Hey @aapot! Thanks for the update - really cool to see the progress you're making here! Sounds like you've got a nice system going for debugging and comparing the PT-FX outputs! That's great the

From the last experiment, it sounds pretty likely that the nn.Modules are now equivalent between PT and Flax (we're getting the same tensors out when we override the Gumbel sampling step). I would suggest we quickly verify that all the loss terms are equivalent with this deterministic set-up. If they match for the first training step, that's perfect - it means we should have a closely matching implementation. Note that with our 'forced sampling' method, we can verify that we get the same losses between PT and Flax, but since we change how the code vectors are computed in Flax (by forcing the sampled Gumbel matrix), we can't expect the gradients to be correct - forcing the Gumbel sampling is going to mess up the backprop, so anything after the first parameter update is going to be divergent. So once we've verified that all the loss terms are the same (contrastive, diversity, total loss), I would re-instate stochastic sampling of the Gumbel matrix in Flax and see whether we can train a stable system! How does that sound?
@sanchit-gandhi sounds reasonable!
Hey @aapot! Awesome - thanks for getting back with these results! Really enjoying your updates here! Shall we double-check that the code vector perplexity is being computed correctly? The diff for this value and the contrastive loss looks a little high for a dummy model (it should be < 1e-5)! We can quickly check the code vector ppl function and verify that it matches PT (and correct it if not!)
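For reference when checking that function: the code vector perplexity is the exponentiated entropy of the batch-averaged soft codevector distribution. The sketch below is my own NumPy paraphrase of that formula (worth verifying against the actual PT implementation, including how it handles codevector groups and the epsilon inside the log):

```python
import numpy as np

def codevector_perplexity(probs, eps=1e-7):
    """exp(entropy) of the batch-averaged codevector distribution.
    probs: (num_vectors, num_codevectors) soft assignments."""
    marginal = probs.mean(axis=0)
    entropy = -np.sum(marginal * np.log(marginal + eps))
    return np.exp(entropy)

# Illustration of one way a perplexity of exactly 2 can arise: if the
# quantizer only ever splits mass over two of its codes, the marginal is
# uniform over 2 entries and exp(entropy) = 2.
uniform_two = np.tile([0.5, 0.5, 0.0, 0.0], (8, 1))
print(codevector_perplexity(uniform_two))  # ≈ 2.0
```

A healthy quantizer should drive this value towards the number of available codevectors (a uniform marginal over 4 codes gives ≈ 4.0), which is why a collapse to 2 signals that most codes are unused.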
Hey @aapot! We're really close here! Any chance you've had the opportunity to look into the code vector perplexity and the error propagation onto the contrastive loss? Once we're confident with these, we can start scaling up to full training runs.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
There is a Jax/Flax-based script available for pretraining wav2vec2 here. I have been trying to pretrain a new wav2vec2 model for Finnish on TPUs using that script, but it seems impossible to get the model to train properly. I know the script is under research_projects, so I am wondering if anyone has been able to successfully pretrain wav2vec2 models with it? Or if anyone has made their own updates to the script to fix potential problems?
For me, it looks like the codevector_perplexity always collapses to a value of 2 and stays there, which I believe is not a good thing. Also, the contrastive loss is usually very unstable. I attached the image below showcasing those issues. In addition, I have tried pretraining wav2vec2 with the official fairseq implementation, where the training looks to work fine without those issues. So I believe the HF Jax/Flax implementation is broken somehow.
Also, I think the HF Jax/Flax wav2vec2 implementation is not fully on par with the HF PyTorch wav2vec2 implementation. For example, I noticed this comment by @patrickvonplaten #14471 (comment) and I think the comment's point number 1 is not implemented in the Jax/Flax version. Also, in PyTorch wav2vec2 pretraining PR comment #13877 (comment), gradient scaling is implemented to avoid issues when training on multiple devices. I wonder if the same would be needed for the Jax/Flax script when training on 8 TPU cores? I tried implementing those myself, but then I found this script where @patrickvonplaten seemed to have already implemented point number 1: https://huggingface.co/patrickvonplaten/wav2vec2-german-flax/blob/main/run_wav2vec2_pretrain_flax.py
Anyhow, even with those potential fixes I haven't been able to get the training to work properly. That's a real pity, since Jax/Flax training would be really great when using TPUs.