Investigate `torch.compile` for the training pipeline #173
Comments
I'll look into this!

Hey @othertea,

I'm gonna try and work on this
Hey @jackapbutler, I started looking into this:

Setup:

Preliminary results from a few

My suggested next steps:

Additional Notes:

Let me know what you think!
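The scraped comment above elides the actual setup and numbers, but the kind of micro-benchmark being discussed can be sketched as follows. This is a minimal, hypothetical harness (toy model and sizes are mine, not from the issue); the key detail is excluding warm-up iterations, since `torch.compile` pays its compilation cost on the first few calls.

```python
# Hypothetical timing harness for comparing eager vs. compiled training
# steps; the model, batch size, and iteration counts are placeholders.
import time
import torch
import torch.nn as nn

def benchmark(step_fn, n_warmup=3, n_iters=10):
    """Average wall-clock time per step, excluding warm-up iterations
    (torch.compile compiles lazily on the first calls)."""
    for _ in range(n_warmup):
        step_fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        step_fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters

model = nn.Sequential(nn.Linear(256, 256), nn.GELU(), nn.Linear(256, 256))
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
x = torch.randn(8, 256)

def train_step():
    opt.zero_grad()
    loss = model(x).pow(2).mean()
    loss.backward()
    opt.step()

eager_time = benchmark(train_step)
# compiled = torch.compile(model)  # then benchmark the same step again
print(f"eager step: {eager_time * 1e3:.2f} ms")
```

In a real comparison you would run the same `benchmark` over the `torch.compile`-wrapped model and report the ratio.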
Hey @othertea, thank you so much, this looks great! Just a few comments / questions:

- Did you check the GPU utilisation (using something like
- The extra memory overhead is a bit concerning, as that's our primary bottleneck for training the larger models on the cluster. Do you have the hardware capacity to trial a slightly larger model such as 160M or 410M? Additionally, there might be other reasons this is happening, so I agree it would be good to investigate further.
- This looks unrelated (it seems prompt tuning adds 10 extra learnable tokens but the model still expects 2048), so I wouldn't worry about this for now. Finally, just another idea, but using
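One way to pin down the memory overhead mentioned above is to compare peak allocations between the eager and compiled step. A minimal sketch, using a toy module as a stand-in (in the issue's setting this would be a Pythia checkpoint such as 160M loaded via `transformers`):

```python
# Sketch: measure peak allocated memory for one training step. The toy
# model and sizes are placeholders; torch.cuda stats need a GPU, so we
# return NaN on CPU rather than guessing.
import torch
import torch.nn as nn

def peak_memory_mb(step_fn, device):
    """Run one step and report peak allocated memory in MB (CUDA only)."""
    if device.type != "cuda":
        return float("nan")  # nvidia-smi / CUDA memory stats need a GPU
    torch.cuda.reset_peak_memory_stats(device)
    step_fn()
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device) / 2**20

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(),
                      nn.Linear(512, 512)).to(device)
x = torch.randn(4, 512, device=device)

def step():
    model(x).sum().backward()

eager_peak = peak_memory_mb(step, device)
# compiled = torch.compile(model); measure the compiled step the same
# way and diff against eager_peak to size the overhead.
print(f"eager peak memory: {eager_peak:.1f} MB")
```

Running this once for the eager model and once for the compiled model gives a direct number for the extra memory `torch.compile` is costing.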
Some updates for you, @jackapbutler: I tried some bleeding-edge versions of packages, and now using

Setup:

Results:

My conclusion is that it might already be worth trying out

Let me know what you think!
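For reference, opting in is a one-line change in plain PyTorch, and the Hugging Face `Trainer` exposes it as a flag (assuming `pytorch >= 2.0` and a recent `transformers`). A minimal sketch; the `backend="eager"` argument here is only to keep the example self-contained without a C++ toolchain, and the default `"inductor"` backend is what you would benchmark in practice:

```python
# Sketch: two ways to enable torch.compile. Toy model for illustration.
import torch

# (a) Plain PyTorch: wrap the model directly.
model = torch.nn.Linear(16, 16)
compiled = torch.compile(model, backend="eager")  # default backend is "inductor"
out = compiled(torch.randn(2, 16))

# (b) Hugging Face Trainer: pass torch_compile=True to TrainingArguments
# (shown as a comment to avoid pulling in the full Trainer stack here):
# args = TrainingArguments(output_dir="out", torch_compile=True)
print(out.shape)
```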
Sounds good @jackapbutler! The PR is here: #251

this is done 🚀
Summary

Investigate if we can use `torch.compile` for our use case (full finetuning + peft models) through the Hugging Face `Trainer`. We are now on `pytorch == 2.0.0`, which means we can potentially take advantage of newer compilation techniques in the library.

Models

We are currently using the Pythia suite of models on Hugging Face, which all use a context length of 2048 and have various sizes available here.

Outputs

- Can we use `torch.compile` with our models?
- `nvidia-smi`, standard Python profilers and/or `tensorboard` might be helpful tools to understand this behaviour
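Of the tools listed above, the standard PyTorch profiler is the easiest to wire in. A minimal sketch with a toy model (a real run would profile the `Trainer`'s step on GPU with CUDA activity enabled):

```python
# Sketch: profile one training step to see where time goes. Toy model;
# add ProfilerActivity.CUDA on a GPU machine.
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))
x = torch.randn(32, 128)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    model(x).sum().backward()

# Top operators by CPU time; torch.profiler.tensorboard_trace_handler
# can export the same trace for viewing in tensorboard.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```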