Recovering from crashed run #74

Open
versae opened this issue Aug 11, 2022 · 3 comments

@versae

versae commented Aug 11, 2022

Hi, thanks for this collection of scripts!

I've been trying to run your run_flax_speech_recognition_ctc.py on a single TPU v3-8, but after a few epochs I always end up running out of memory (not sure if it's caused by a memory leak or something else). I tried to recover from the last checkpoint by skipping the number of steps at which the model was last saved and setting the learning rate appropriately. I also tried modifying MixedPrecisionTrainState.create() so it starts at the last saved checkpoint step. Nothing worked: as soon as training resumes, it runs out of memory again. Any idea what could be happening?
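Roughly what I mean by that last attempt, as a minimal sketch: the stock Flax TrainState stands in for MixedPrecisionTrainState (assuming it follows the same interface), and the model, optimiser, and step values are toy placeholders, not the script's.

```python
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState

# Toy stand-ins -- not the actual wav2vec2 model or the script's optimiser.
params = {"w": jnp.zeros((4,))}
tx = optax.adamw(learning_rate=1e-4)

resume_step = 120_000  # hypothetical step of the last saved checkpoint

state = TrainState.create(apply_fn=lambda p, x: x, params=params, tx=tx)
# Bump the step counter so schedules/logging resume at the checkpoint step.
# Note the optimiser moments are still freshly initialised (zeros) here.
state = state.replace(step=resume_step)
print(state.step)  # 120000
```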

Thanks!

@sanchit-gandhi
Owner

Hey @versae! Glad to hear you're enjoying using these scripts!

Hmmm, that's very interesting! The closest thing I've seen to that is when I greatly reduced the pad_input_to_multiple_of value to below 16000. There, I got OOMs after 5-10k optimisation steps. I presumed it was the number of compiled binaries increasing (bucketing the inputs into more granular chunks means more distinct padded shapes), but didn't dig into it too deeply.
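To illustrate what I mean about the shapes, here's a toy sketch (not the training script; the sample lengths are made up): padding each input up to a multiple buckets the lengths, and a smaller multiple yields more distinct padded shapes, each of which XLA compiles and caches.

```python
import math

def padded_length(raw_length: int, multiple: int) -> int:
    """Round a raw audio length up to the nearest multiple of `multiple`."""
    return math.ceil(raw_length / multiple) * multiple

# Hypothetical raw sample lengths (in audio frames).
raw_lengths = [5_000, 12_000, 29_000, 31_500, 47_000, 63_000]

for multiple in (32_000, 16_000, 4_000):
    shapes = sorted({padded_length(length, multiple) for length in raw_lengths})
    print(f"pad_input_to_multiple_of={multiple}: {len(shapes)} distinct shapes -> {shapes}")
```

Fewer distinct shapes means fewer compilations (and fewer cached binaries sitting in memory), at the cost of more padding per batch.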

Do you have an example script I could use to replicate? I have a v3-8 sitting idle that I could use to emulate this behaviour!

Utils for properly loading model weights and optimiser states from saved checkpoints are definitely two things that need to be added! We can probably look to Dalle-mini for help on this: https://github.com/borisdayma/dalle-mini/blob/fc83bc9280772e475946a1b258fe10eba3e5ab8f/tools/train/train.py#L1131
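Something along these lines could work as a starting point; it's a rough sketch with flax.serialization in the spirit of the Dalle-mini restore logic, not the repo's actual code, and the file layout plus the assumption that MixedPrecisionTrainState follows the usual TrainState interface (.params, .opt_state, .step, .replace()) are mine:

```python
from pathlib import Path

from flax import serialization

def save_state(state, ckpt_dir: str) -> None:
    """Serialise params, optimiser state, and step to msgpack/text files."""
    ckpt = Path(ckpt_dir)
    ckpt.mkdir(parents=True, exist_ok=True)
    (ckpt / "params.msgpack").write_bytes(serialization.to_bytes(state.params))
    (ckpt / "opt_state.msgpack").write_bytes(serialization.to_bytes(state.opt_state))
    (ckpt / "step.txt").write_text(str(state.step))

def restore_state(state, ckpt_dir: str):
    """Restore params, optimiser state, and step into a freshly created state."""
    ckpt = Path(ckpt_dir)
    params = serialization.from_bytes(state.params, (ckpt / "params.msgpack").read_bytes())
    opt_state = serialization.from_bytes(state.opt_state, (ckpt / "opt_state.msgpack").read_bytes())
    step = int((ckpt / "step.txt").read_text())
    # Restoring the step means the LR schedule picks up exactly where it left off.
    return state.replace(step=step, params=params, opt_state=opt_state)
```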

@versae
Author

versae commented Aug 24, 2022

Thanks for the quick reply (much quicker than mine now). I use the default pad_input_to_multiple_of value of 32000. The OOM didn't occur until epoch 14/40 of a fairly big dataset (~740k steps). I also tried filtering out audio clips of different lengths, but I still got OOM errors very late into training. Here's an example repo with a crashed run: https://huggingface.co/NbAiLab/wav2vec2-1b-npsc-nst-tpu/settings.

Boris' restore state is probably the way to go. In my modified training script, I thought I could achieve the same by using skip_steps and setting the learning rate to the last value before crashing. But it does not work :'(
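For reference, this is roughly how I understand the learning-rate side (a toy optax sketch with made-up values, not the script's actual schedule): since the schedule is a pure function of the step count, pinning the LR to its last value gives a different trajectory than restoring the step and letting the schedule continue.

```python
import optax

# Hypothetical linear decay over the whole run.
total_steps = 740_000
schedule = optax.linear_schedule(init_value=1e-4, end_value=0.0,
                                 transition_steps=total_steps)

resume_step = 350_000  # hypothetical step of the crashed checkpoint
print("LR the schedule would resume at:", schedule(resume_step))
# From here on the LR keeps decaying, whereas a constant "last value" would not.
```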

@sanchit-gandhi
Owner

Ah that's really frustrating! Sorry to hear it happened so late into training :/

What happens when you try to correct for the LR? Does the loss explode? Saving the optimiser states could help here (rather than re-initialising the momentum terms).
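To illustrate why the moments matter, here's a toy optax sketch (not the training code): Adam's update depends on the running moments stored in the optimiser state, so a freshly initialised state with zeroed moments plus the "correct" LR still produces different updates than a restored state would.

```python
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((3,))}       # toy parameters
tx = optax.adamw(learning_rate=1e-4)

big_grads = {"w": jnp.full((3,), 5.0)}
small_grads = {"w": jnp.full((3,), 0.01)}

# State that has already accumulated moments from a large gradient.
warm_state = tx.init(params)
_, warm_state = tx.update(big_grads, warm_state, params)

# Freshly re-initialised state, as after a restart without restoring opt_state.
fresh_state = tx.init(params)

warm_updates, _ = tx.update(small_grads, warm_state, params)
fresh_updates, _ = tx.update(small_grads, fresh_state, params)
print(warm_updates["w"])   # damped: the second moment still remembers the large gradient
print(fresh_updates["w"])  # larger step: zeroed moments take the small gradient at face value
```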
