
Investigate torch.compile for the training pipeline #173

Closed
jackapbutler opened this issue Apr 18, 2023 · 9 comments
Labels: enhancement (New feature or request), investigation (Research new functionality), priority: medium, work package: model training (Relates to the model training pipeline)

jackapbutler (Collaborator) commented Apr 18, 2023

Summary

Investigate whether we can use torch.compile for our use case (full finetuning + PEFT models) through the Hugging Face Trainer.

  • Our current requirements allow for pytorch == 2.0.0, which means we can potentially take advantage of the newer compilation techniques in the library.
  • This task consists of investigating the benefits of this functionality for our use case, comparing speed improvements at different model sizes, and finally integrating it into our pipeline configuration depending on the results.
  • It is totally valid to work with smaller model sizes (< 1B) locally.

Models

We are currently using the Pythia suite of models on Hugging Face, which all use a context length of 2048 and are available in various sizes.

Outputs

  • Can we use torch.compile with our models?
  • If so, how much faster is the forward pass / model training with it?
    • nvidia-smi, standard Python profilers and/or tensorboard might be helpful tools to understand this behaviour

This can be broken into subtasks as required. A rough benchmarking sketch is given below as a starting point.
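For context, a minimal sketch of what this benchmarking could look like outside the training pipeline (the model name, batch contents, and step counts are illustrative assumptions, not part of our pipeline):

```python
# Rough timing sketch: compare the eager vs. torch.compile forward pass of a small
# Pythia checkpoint. Assumes torch >= 2.0 and transformers are installed.
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "EleutherAI/pythia-70m"  # small enough to run locally
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Pythia tokenizers ship without a pad token
model = AutoModelForCausalLM.from_pretrained(model_name).to(device).eval()

inputs = tokenizer(
    ["Benzene is an aromatic hydrocarbon."] * 8, return_tensors="pt", padding=True
).to(device)


def time_forward(m, steps=20):
    with torch.no_grad():
        for _ in range(3):  # warm-up; compilation happens on the first call
            m(**inputs)
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(steps):
            m(**inputs)
        if device == "cuda":
            torch.cuda.synchronize()
        return (time.perf_counter() - start) / steps


eager = time_forward(model)
compiled = time_forward(torch.compile(model))
print(f"eager: {eager * 1e3:.1f} ms/step, compiled: {compiled * 1e3:.1f} ms/step")
```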

othertea (Contributor) commented:

I'll look into this!

jackapbutler (Collaborator, Author) commented May 2, 2023

Hey @othertea,

  1. Regarding the training code, the only script you should need is chemnlp/experiments/scripts/run_tune.py. The associated training configurations are in chemnlp/experiments/configs, and you can test with the smaller models (the 160M and 410M parameter configs) because we're mainly interested in benchmarking the speed of the forward pass. You should be able to run that script locally in the repository's Python environment, but let me know if something comes up.
  2. Regarding the dataset, we're currently cleaning the chemrxiv dataset further and it's also quite large, so it might be easiest for you to (see the rough sketch after this list):
    a. take a smaller NLP dataset like a subset of ELI5,
    b. tokenise it with experiments/data/prepare_hf_chemrxiv.py (changing lines 16 and 17),
    c. then run this through the pipeline (changing the data -> path: key in the configurations).
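To illustrate steps (a) and (b), something roughly like the following should work. This is a stand-in sketch, not the actual prepare_hf_chemrxiv.py script; the dataset split, text column, and output path are assumptions.

```python
# Sketch of preparing a small tokenised dataset for the pipeline; the ELI5 subset,
# the "title" text column, and the output path are illustrative choices only.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

raw = load_dataset("eli5", split="train_asks[:1000]")  # any small text dataset works


def tokenize(batch):
    # swap "title" for whichever field actually holds the text
    return tokenizer(batch["title"], truncation=True, max_length=2048)


tokenized = raw.map(tokenize, batched=True, remove_columns=raw.column_names)
tokenized.save_to_disk("data/eli5_tokenized")  # point data -> path: at this directory
```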

jamesthesnake commented:

I'm gonna try and work on this

othertea (Contributor) commented May 3, 2023

Hey @jackapbutler, I started looking into this:

Setup:

  • I worked off of commit 6872587.
  • I am using transformers=4.27.3 and torch=2.0.0+cu117.
  • I started looking into this before your comment, so I arbitrarily decided to use the arXiv subset of CarperAI/pile-v2-small-filtered, processed with prepare_hf_chemrxiv.py. Its train split contains 1000 rows.
  • For an even smaller model, I used the EleutherAI/pythia-70m model by modifying line 9 of 160M_full.yml.
  • I tested torch.compile via the torch_compile argument in TrainingArguments (a minimal example is sketched after this list). Note that calling torch.compile() on the model before passing it to Trainer does not work.
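For concreteness, the relevant part of my setup looked roughly like this (the output directory, batch size, and epoch count are placeholders, and the tokenised dataset is the one described above):

```python
# Enable compilation through the Trainer rather than wrapping the model manually;
# requires torch >= 2.0 and a transformers version exposing the torch_compile flag.
from datasets import load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
tokenizer.pad_token = tokenizer.eos_token

args = TrainingArguments(
    output_dir="outputs/compile-test",
    per_device_train_batch_size=8,
    num_train_epochs=10,
    torch_compile=True,  # the flag under test; Trainer applies torch.compile internally
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=load_from_disk("data/eli5_tokenized"),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```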

Preliminary results from a few run_tune.py runs:

  • Speed:
    • For 1 epoch, using torch.compile had minimal effect. On average it was faster by 2%, but that's well within noise over just a few runs.
    • For 10 epochs, using torch.compile was faster by ~30%.
    • Since this is a small toy dataset, I expect the latter case to be more representative of chemnlp's settings and that torch.compile will help with training time when running the real setup. But more experimentation is needed.
  • Memory:
    • Using torch.compile increased memory usage from 15GB to 19GB.

My suggested next steps:

  • More investigation into the higher memory usage.
  • Trying out torch.compile on the chemrxiv dataset. More generally, we should get a sense of whether torch.compile will help as we scale our models -- my guess is yes.
  • More repeated runs for higher confidence in results.

Additional Notes:
I also tried out 160M_ptune.yaml, but I ran into the following error:

  File ".../transformers/models/gpt_neox/modeling_gpt_neox.py", line 220, in _attn
    attn_scores = torch.where(causal_mask, attn_scores, mask_value)
RuntimeError: The size of tensor a (2048) must match the size of tensor b (2058) at non-singleton dimension 3

Let me know what you think!

jackapbutler (Collaborator, Author) commented May 4, 2023

Hey @othertea, thank you so much, this looks great! Just a few comments / questions:

For 1 epoch, using torch.compile had minimal effect. On average it was faster by 2%, but that's well within noise over just a few runs.

Did you check the GPU utilisation (using something like nvidia-smi) over this training run? I imagine the model may be performing its computation faster but getting bottlenecked by data-loading speed.

Using torch.compile increased memory usage from 15GB to 19GB.

The extra memory overhead is a bit concerning, as that's our primary bottleneck for training the larger models on the cluster. Do you have the hardware capacity to trial a slightly larger model such as 160M or 410M? There may also be other reasons this is happening, so I agree it would be good to investigate further.

I also tried out 160M_ptune.yaml, but I ran into the following error:

This looks unrelated (it seems prompt tuning adds 10 extra learnable tokens while the model still expects a maximum of 2048); I wouldn't worry about it for now.

Finally, another idea: using tensorboard to profile the NN operators would help us identify which operators are actually faster under the new setup and exclude any noise outside of the actual forward / backward pass. A rough sketch of this is below.
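Something along these lines with the PyTorch profiler and its TensorBoard trace handler should do it (the model, inputs, and log directory below are placeholders rather than our pipeline code):

```python
# Operator-level profiling sketch using torch.profiler's TensorBoard trace handler;
# view the traces with `tensorboard --logdir ./tb_profile` (needs torch-tb-profiler).
import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m").to(device)
model = torch.compile(model)  # drop this line to profile the eager baseline

tok = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
batch = tok("Caffeine is a xanthine alkaloid.", return_tensors="pt").to(device)

activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

with profile(
    activities=activities,
    schedule=schedule(wait=1, warmup=1, active=3),
    on_trace_ready=tensorboard_trace_handler("./tb_profile"),
    record_shapes=True,
) as prof:
    for _ in range(6):  # enough steps to cover wait + warmup + active
        loss = model(**batch, labels=batch["input_ids"]).loss
        loss.backward()
        model.zero_grad(set_to_none=True)
        prof.step()
```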

othertea (Contributor) commented:

Some updates for you, @jackapbutler: I tried some bleeding-edge versions of packages, and now using torch.compile() takes less time to train without using more memory. In fact, it uses a little bit less memory!

Setup:

  • I switched to installing transformers from source: I installed with pip from commit ef42c2c.
  • I installed a nightly PyTorch, version 2.1.0.dev20230506+cu117.
  • I had to do some ad hoc fiddling with packages and libraries to make everything work, so unfortunately the above may not be a sufficient characterization of the changes required to resolve the memory issue.
  • This time, I also tested the pythia-160m model.

Results:

  • Memory
    • For the pythia-160m model, torch.compile() reduced memory usage by ~0.8GB.
    • For the pythia-70m model, torch.compile() reduced memory usage by ~1.5GB.
    • It would be nicer if the memory savings did not shrink with increasing model size, but 160M is the largest model I could test.
  • Speed
    • For the pythia-160m model, over 5 epochs, torch.compile reduced training time by ~30%.
    • For the pythia-70m model, over 10 epochs, torch.compile reduced training time by ~30%. In terms of absolute times, both compiled and not-compiled versions took a little longer than the last time I ran it. I'm not sure why -- this may be due to hardware (VM) differences, or due to the package updates, or something else. But the compiled version from this time still ran faster than the not-compiled version from last time and used less memory, so it should still be strictly better.
  • I checked GPU utilization for the 1-epoch run as you suggested, and it did indeed seem lower for the torch-compiled version.
  • Something I noticed is that the torch-compiled version prints duplicates of this warning message:
    torch._inductor.utils: [WARNING] DeviceCopy in input program
    
    which does not appear without torch.compile. I think this is another one of those issues that will be ironed out by the PyTorch or Hugging Face team.

My conclusion is that it might already be worth trying out torch.compile() with some bleeding-edge packages for the standard chemnlp setup, especially single-node experiments. Even if you don't want the nightly versions, I expect the relevant changes won't take too long to be available via the standard package managers.

Let me know what you think!

jackapbutler (Collaborator, Author) commented May 15, 2023

Hey @othertea, this looks great! Feel free to open a pull request adding compilation to the pipeline's TrainerConfig configuration, and I'll test it out on the Stability cluster to check that it works in that environment as well 👍 A purely hypothetical sketch of what I mean is below.
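To be clear about the shape of the change (this is only a hypothetical illustration; the actual TrainerConfig in the repo and the eventual PR may look different, and the field names here are assumptions):

```python
# Hypothetical wiring only: expose a torch_compile flag on the pipeline config and
# pass it straight through to the Hugging Face TrainingArguments.
from dataclasses import dataclass

from transformers import TrainingArguments


@dataclass
class TrainerConfig:  # stand-in for the pipeline's real config class
    output_dir: str
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 8
    torch_compile: bool = False  # the new flag under discussion


def to_training_arguments(cfg: TrainerConfig) -> TrainingArguments:
    return TrainingArguments(
        output_dir=cfg.output_dir,
        num_train_epochs=cfg.num_train_epochs,
        per_device_train_batch_size=cfg.per_device_train_batch_size,
        torch_compile=cfg.torch_compile,
    )
```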

othertea (Contributor) commented:

Sounds good @jackapbutler! The PR is here: #251

jackapbutler (Collaborator, Author) commented:

this is done 🚀
