Kangaroo

 Kangaroo: Lossless Self-Speculative Decoding via Double Early Exiting

arXiv Paper: https://arxiv.org/abs/2404.18911



Drawing inspiration from early exiting, we propose Kangaroo, a novel self-speculative decoding framework that uses a fixed shallow sub-network as the self-draft model, with the remaining layers serving as the larger target model. We train a lightweight and efficient adapter module on top of the sub-network to bridge the gap in representation ability between the sub-network and the full model. The adapter network consists of only one multi-head attention layer and two normalization layers. Surprisingly, we find this simple design both efficient and powerful. To further reduce the inference latency of the self-draft model, we introduce an additional early-exiting mechanism during draft-token generation, which avoids spending unnecessary computation on more difficult tokens.
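For intuition, below is a minimal PyTorch sketch of what such an adapter module could look like. The class name, hidden size, and head count are illustrative assumptions, not the repository's actual implementation.

import torch.nn as nn

class KangarooStyleAdapter(nn.Module):
    # Illustrative adapter: one multi-head attention block wrapped by two
    # normalization layers, bridging the shallow sub-network's hidden states
    # to the full model's representation space. Shapes/names are assumptions.
    def __init__(self, hidden_size=4096, num_heads=32):
        super().__init__()
        self.pre_norm = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.post_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, hidden_size) taken from the exit layer
        x = self.pre_norm(hidden_states)
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        return self.post_norm(attn_out)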

TODO List

  • inference code & checkpoints of Kangaroo.
  • code for training Kangaroo.
  • tree verification.
  • support for batch size > 1 and decoding with sampling.

Training

We follow the training procedures of Medusa and EAGLE.

  1. Data preprocessing:
cd data
python allocation.py --outdir /home/ma-user/work/Data/
  2. Training:
python start_train.py

Inference

## Vicuna-7B as an example

## Vanilla decoding
CUDA_VISIBLE_DEVICES=0 python -m evaluation.inference_baseline --model-path "/cache/CKPT/vicuna-7b-v1.3" --model-id "vicuna-7b-v1.3-vanilla-float16-temp-0.0" --bench-name "Kangaroo" --temperature 0.0 --dtype "float16"

## Kangaroo
CUDA_VISIBLE_DEVICES=0 python -m evaluation.inference_kangaroo --adapter-path "/cache/CKPT/kangaroo-vicuna-7b-v1.3" --exitlayer 2 --model-path "/cache/CKPT/vicuna-7b-v1.3" --threshold 0.6 --steps 6 --model-id "vicuna-7b-v1.3-kangaroo-thres-0.6-steps-6-float16" --bench-name "Kangaroo" --dtype "float16"

To get detailed speed information, run python evaluation/speed.py.
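For intuition, the --threshold and --steps flags above control the second early exit, which happens during drafting. A minimal sketch of that confidence-based drafting loop is shown below; draft_forward is a hypothetical callable standing in for the adapter-equipped sub-network, not the repository's actual API.

import torch

@torch.no_grad()
def draft_with_early_exit(draft_forward, input_ids, threshold=0.6, max_steps=6):
    # Keep proposing tokens with the shallow self-draft model until its
    # top-token confidence drops below `threshold` or `max_steps` tokens have
    # been drafted; the drafts are then verified in one pass by the full
    # target model, which keeps the final output lossless.
    draft_tokens = []
    for _ in range(max_steps):
        logits = draft_forward(input_ids)        # assumed: logits for the last position
        probs = torch.softmax(logits, dim=-1)
        confidence, token = probs.max(dim=-1)
        if confidence.item() < threshold:
            break                                # too uncertain: exit drafting early
        draft_tokens.append(token.item())
        input_ids = torch.cat([input_ids, token.view(1, 1)], dim=-1)
    return draft_tokens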

The corresponding Hugging Face checkpoints of Kangaroo can be downloaded from the Kangaroo Google Drive.

Citation

@article{liu2024kangaroo,
  title={Kangaroo: Lossless Self-Speculative Decoding via Double Early Exiting},
  author={Liu, Fangcheng and Tang, Yehui and Liu, Zhenhua and Ni, Yunsheng and Han, Kai and Wang, Yunhe},
  journal={arXiv preprint arXiv:2404.18911},
  year={2024}
}

Acknowledgements

We acknowledge the authors of Medusa and EAGLE.

License

License: MIT
