[RFC] Add LayerSkip to AO #633

Open
jcaip opened this issue Aug 8, 2024 · 4 comments
Labels: enhancement New feature or request

jcaip commented Aug 8, 2024

Tracker issue for adding LayerSkip to AO.

This is a training and inference optimization that is similar to layer-wise pruning. It's particularly interesting for LLM inference because it combines very cleanly with speculative decoding to provide up to a 1.86x speedup.

@mostafaelhoushi is interested in adding this to torchtune and in upstreaming a subset of the code to ao. See here for more details. In particular, he'd like to do this without having to alter the module definition.

This is attractive because this part of LayerSkip is not unique to LLMs and can be used for other models. (@mostafaelhoushi to fill out with relevant results).

What is being proposed:

For LayerSkip there is a training recipe and an inference recipe:

  • training recipe:
    • layer dropout: This is skipping layers stochastically during training (see the sketch after this list). This is what I think we can get into torch.ao, because it could benefit all types of models: transformers, CNNs, vision, text, etc. It can speed up training and might improve accuracy.
    • early exit loss: This could also be added to torch.ao and it could help different modalities, but may require more time.
  • inference recipe:
    • self-speculative decoding: drafting with the early layers and verifying with the full model. This is more LLM-specific, and it is where most of the end-to-end speedup comes from.
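For concreteness, here is a minimal sketch of what a generic layer-dropout wrapper could look like; the `LayerDropout` name and its interface are hypothetical illustrations, not an existing torchao or torchtune API:

```python
import torch
import torch.nn as nn

class LayerDropout(nn.Module):
    """Hypothetical wrapper that stochastically skips the wrapped layer during
    training (a simple form of stochastic depth). At eval time the layer always runs."""

    def __init__(self, layer: nn.Module, p: float = 0.2):
        super().__init__()
        self.layer = layer
        self.p = p  # probability of skipping the layer at a given training step

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training and torch.rand(()) < self.p:
            return x  # skip the layer entirely and fall back to the identity path
        return self.layer(x)
```

Wrapping a block this way is model-agnostic (it only assumes the layer preserves the input shape), which is the argument for putting it in ao rather than in an LLM-specific library.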
gau-nernst commented:

Layer dropout during training looks like some form of stochastic depth. Some related implementations already exist (e.g. timm's DropPath).

A glance at the LayerSkip paper suggests that they mask each sample in a batch independently. We probably need some tricks to see speedups? The torchtune PR implements it by indexing, applying the function, and writing back a subset of the batch (a rough sketch follows below). Curious to see whether the extra overhead is outweighed by the reduced computation during training.
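A rough sketch of that indexing approach (illustrative only, not the actual torchtune PR code): run the layer only on the surviving rows of the batch and write the result back.

```python
import torch
import torch.nn as nn

def per_sample_layer_dropout(layer: nn.Module, x: torch.Tensor, p: float) -> torch.Tensor:
    """Drop the layer independently per sample: dropped samples pass through
    unchanged, kept samples go through the layer. Assumes x has shape (batch, ...)
    and that layer(x) preserves that shape."""
    keep = torch.rand(x.shape[0], device=x.device) >= p  # per-sample keep mask
    out = x.clone()                                       # dropped rows are just copied through
    if keep.any():
        out[keep] = layer(x[keep])                        # compute only on the kept subset
    return out
```

Whether the indexing and copy-back overhead is outweighed by the skipped computation is exactly the open question here; it likely depends on batch size and per-layer cost.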

jcaip commented Aug 8, 2024

Yup, the layer dropout aspect of LayerSkip is basically a version of stochastic depth. That's part of the reason I'm interested in having it in AO: a generic stochastic depth function / module would be useful beyond just LLMs.

IIRC from talking to Mostafa, masking + rewriting is faster, but the speedups mostly come from the self-speculative decoding part of the technique.

@mostafaelhoushi can you share some benchmarks about the layer dropout implementation specifically when you update the issue? Thanks.

msaroufim added the enhancement (New feature or request) label Aug 9, 2024
mostafaelhoushi commented Aug 15, 2024

Sorry for the delay from my side.

Other Papers

I would like to mention other papers or models that used layer dropout (aka stochastic depth):

  • Vision Models:
    • It was first explored in ResNets by Huang et al., 2016.
    • ConvNeXt uses it in its training recipe, with higher layer dropout rates for larger models: 0.1/0.4/0.5/0.5 for ConvNeXt-T/S/B/L respectively when trained on ImageNet. However, when training on the larger ImageNet-22K it uses smaller layer dropout rates: 0.0/0.0/0.1/0.1/0.2.
    • Layer dropout is also commonly used in vision transformers; Swin Transformers use higher rates for larger models: 0.2, 0.3, and 0.5 for Swin-T, Swin-S, and Swin-B, respectively.
    • DINOv2 also used layer dropout during training (cc @danthe3rd).
  • NLP Models:
    • LayerDrop increased the accuracy of RoBERTa and machine-translation Transformer models by applying dropout to every other transformer layer, and made them more robust to removing layers at inference time. A dropout rate of 0.2 was used, with a higher rate of 0.5 recommended for smaller models.
    • Progressive Layer Dropping increased the pretraining speed of BERT by 1.86x by applying a dropout rate that progressively increases over time (every iteration) and with depth (every layer), up to a maximum dropout rate of 0.5. The paper found that layer skipping was robust to higher learning rates, which was one of the causes of the training speedup.

Other Implementations

  • Regarding per-batch vs per-sample dropout, we can implement both. However, I would like to mention that
    • when I measured the speedup of both, I found them to be similar.
    • per-batch dropout causes errors when training with FSDP. This is because FSDP expects to find gradients for all modules in a model. Per-batch layer dropout can cause a module to have neither a forward nor a backward pass in an iteration, which makes FSDP throw an error. There are mitigations for this, e.g., overriding the backward pass (as done in fairseq2).
  • My understanding is that the timm implementation doesn't lead to a speedup, as it replaces samples with zeros rather than skipping compute, right? (See the sketch below.)
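For contrast with the indexing approach sketched earlier, here is a masking-style (timm DropPath-like) variant: the layer still runs on every sample, and the dropped samples' residual branch is zeroed while the survivors are rescaled, so it regularizes but saves no compute. This is an illustration, not timm's actual code.

```python
import torch

def drop_path_mask(residual: torch.Tensor, p: float, training: bool) -> torch.Tensor:
    """Zero the residual branch for dropped samples and rescale the kept ones.
    The full layer has already been computed for every sample, so there is no
    speedup; only the per-sample indexing variant actually skips work."""
    if not training or p == 0.0:
        return residual
    keep_prob = 1.0 - p
    keep = torch.rand(residual.shape[0], device=residual.device) < keep_prob
    shape = (residual.shape[0],) + (1,) * (residual.dim() - 1)
    return residual * keep.view(shape).to(residual.dtype) / keep_prob

# usage inside a residual block (hypothetical): x = x + drop_path_mask(block(x), p, training)
```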

Benchmark Results

On TorchTune, I ran this command on a single A100 GPU

$ tune run --nproc_per_node 1 full_finetune_distributed --config llama3/8B_full output_dir=$CKPT_PATH checkpointer.checkpoint_dir=$CKPT_PATH/original checkpointer.output_dir=$CKPT_PATH tokenizer.path=$CKPT_PATH/original/tokenizer.model batch_size=16

and got these measurements:

| Maximum Dropout | Dropout Scale Across Layers | Time to Reach 50 Iterations | Speedup |
| --- | --- | --- | --- |
| None | – | 01 min 32 sec | 1x |
| 0.2 | Uniform | 01 min 23 sec | 1.07x |
| 0.3 | Uniform | 01 min 17 sec | 1.19x |
| 0.5 | Uniform | 01 min 05 sec | 1.42x |
| 0.5 | Linear | TBD | TBD |
| 0.2 | Exponential | 01 min 30 sec | 1.02x |
| 0.5 | Exponential | 01 min 22 sec | 1.12x |
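For reference, one way the "Dropout Scale Across Layers" column could map to per-layer rates; the formulas below are illustrative assumptions, not necessarily the exact schedule used in the runs above.

```python
import math

def layer_dropout_rates(num_layers: int, p_max: float, scale: str = "uniform") -> list[float]:
    """Illustrative per-layer dropout schedules: 'uniform' applies p_max to every
    layer, 'linear' ramps from 0 at the first layer to p_max at the last, and
    'exponential' grows toward p_max so early layers are rarely skipped."""
    rates = []
    for l in range(num_layers):
        frac = l / max(num_layers - 1, 1)
        if scale == "uniform":
            rates.append(p_max)
        elif scale == "linear":
            rates.append(p_max * frac)
        elif scale == "exponential":
            rates.append(p_max * (math.exp(frac * math.log(2.0)) - 1.0))  # 0 at layer 0, p_max at the last layer
        else:
            raise ValueError(f"unknown scale: {scale}")
    return rates
```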

I also want to tag @danthe3rd as he guided me to implement the per-sample layer dropout and he has implemented it for Dinov2.

yanbing-j pushed a commit to yanbing-j/ao that referenced this issue Dec 9, 2024
Xynonners commented:

https://arxiv.org/abs/2402.17812 DropBP may be an option too (it only skips layers in the backward pass; the authors claim it leads to better accuracy).
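One plain-autograd reading of that idea, under the assumption that DropBP keeps the forward pass intact and only skips gradient computation through a residual branch (a sketch of the concept, not the authors' implementation):

```python
import torch
import torch.nn as nn

def dropbp_residual(branch: nn.Module, x: torch.Tensor, p: float, training: bool) -> torch.Tensor:
    """Always run the branch in the forward pass, but with probability p compute it
    under no_grad so the backward pass skips it: gradients then flow only through
    the skip connection for that iteration."""
    if training and torch.rand(()) < p:
        with torch.no_grad():
            out = branch(x)   # forward runs, but no autograd graph is built for the branch
        return x + out        # gradient still flows through the skip connection
    return x + branch(x)
```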
