
weigao266/fairseq-CO2

 
 


Fairseq-CO2

This repository provides an example of using CO2 (Efficient Distributed Training with Full Communication-Computation Overlap) within Fairseq.


Requirements and Installation

  • PyTorch version >= 1.10.0
  • Python version >= 3.8
  • For training new models, you'll also need an NVIDIA GPU and NCCL
  • To install fairseq-CO2 and develop locally:
git clone https://github.com/weigao266/fairseq-CO2.git
cd fairseq-CO2
pip install --editable ./
  • CO2 itself is implemented in a fork of FairScale, fairscale-CO2. To install fairscale-CO2 and develop locally:
git clone https://github.com/weigao266/fairscale-CO2.git
cd fairscale-CO2
pip install --editable ./
  • For faster training, install NVIDIA's apex library:
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
  --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
  --global-option="--fast_multihead_attn" ./
  • For large datasets, install PyArrow: pip install pyarrow
  • If you use Docker, make sure to increase the shared memory size, either with --ipc=host or --shm-size, as command-line options to nvidia-docker run.
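The Python, PyTorch, GPU and NCCL requirements above can be checked before running anything heavier. The following is a minimal sketch and is not part of fairseq-CO2. The helpers `parse_version` and `check_environment` are hypothetical names. The script relies only on standard PyTorch calls (`torch.__version__`, `torch.cuda.is_available`, `torch.distributed.is_available`, `torch.distributed.is_nccl_available`).

```python
"""Hypothetical pre-flight check for the requirements listed above.

Not part of fairseq-CO2; helper names are illustrative only.
"""
import sys
from typing import List, Tuple


def parse_version(version: str) -> Tuple[int, int, int]:
    """Turn strings such as '2.1.0+cu118' or '1.10.0a0' into (major, minor, patch)."""
    core = version.split("+")[0]  # drop local build tags like '+cu118'
    parts: List[int] = []
    for piece in core.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at pre-release markers such as 'a0' or 'rc1'
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)  # pad '2.0' to (2, 0, 0) so tuple comparison is well defined
    return parts[0], parts[1], parts[2]


def check_environment() -> List[str]:
    """Return human-readable problems; an empty list means every check passed."""
    problems: List[str] = []
    if sys.version_info < (3, 8):
        problems.append(f"Python >= 3.8 required, found {sys.version.split()[0]}")
    try:
        import torch
        import torch.distributed as dist
    except ImportError:
        problems.append("PyTorch is not installed")
        return problems
    if parse_version(torch.__version__) < (1, 10, 0):
        problems.append(f"PyTorch >= 1.10.0 required, found {torch.__version__}")
    if not torch.cuda.is_available():
        problems.append("no CUDA-capable NVIDIA GPU is visible (needed for training)")
    # is_nccl_available is only defined when distributed support is compiled in,
    # so check is_available() first (the `and` short-circuits).
    if not (dist.is_available() and dist.is_nccl_available()):
        problems.append("this PyTorch build has no NCCL backend")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    print("\n".join(issues) if issues else "environment looks OK")
```

Padding versions to three components keeps comparisons such as `(2, 0, 0) >= (1, 10, 0)` well defined. Without padding, Python would treat the shorter tuple `(1, 10)` as smaller than `(1, 10, 0)`.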

Usage

Run the script run_co2_local.sh to train a GPT-2 (Medium) model with 355M parameters using CO2. The script sets co2_base_algorithm = localsgd and co2_outer_momentum = 0.2.

cd co2_examples
bash run_co2_local.sh
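The two settings in the script select local SGD as CO2's base algorithm and set its outer momentum. The toy below is only a rough illustration of those two ideas on a 1-D quadratic. Each simulated worker takes several SGD steps without communicating. The workers are then averaged, and the averaged change is applied through a heavy-ball momentum buffer. The sketch assumes that the outer momentum behaves like a SlowMo-style outer step. It is not the fairscale-CO2 code. It leaves out CO2's asynchronous communication overlap and the other mechanisms described in the paper. All function names are hypothetical.

```python
"""Toy sketch: local SGD with a SlowMo-style outer momentum step (1-D, no real communication).

Assumption: this only illustrates what co2_base_algorithm=localsgd and
co2_outer_momentum roughly control; it is NOT the fairscale-CO2 implementation.
"""
from typing import List, Tuple


def local_sgd_round(x: float, targets: List[float], lr: float, local_steps: int) -> List[float]:
    """Each simulated worker starts from the shared x and runs `local_steps` SGD
    steps on its own loss 0.5 * (w - target)**2, with no communication."""
    workers = []
    for target in targets:
        w = x
        for _ in range(local_steps):
            grad = w - target  # gradient of 0.5 * (w - target)**2
            w -= lr * grad
        workers.append(w)
    return workers


def outer_step(x: float, worker_params: List[float], momentum_buf: float,
               outer_momentum: float, outer_lr: float = 1.0) -> Tuple[float, float]:
    """Average the workers (stand-in for an all-reduce), treat the change as a
    pseudo-gradient, and apply it through a heavy-ball momentum buffer."""
    x_avg = sum(worker_params) / len(worker_params)
    pseudo_grad = x - x_avg
    momentum_buf = outer_momentum * momentum_buf + pseudo_grad
    return x - outer_lr * momentum_buf, momentum_buf


def train(targets: List[float], rounds: int = 50, lr: float = 0.1,
          local_steps: int = 4, outer_momentum: float = 0.2) -> float:
    """Alternate local SGD rounds and outer momentum steps; return the shared model."""
    x, buf = 0.0, 0.0
    for _ in range(rounds):
        workers = local_sgd_round(x, targets, lr, local_steps)
        x, buf = outer_step(x, workers, buf, outer_momentum)
    return x


if __name__ == "__main__":
    # Two workers with different local optima; the shared model approaches their mean.
    print(round(train([1.0, 3.0]), 4))  # prints 2.0
```

With `outer_momentum=0.0` and `outer_lr=1.0`, `outer_step` returns the plain average of the workers. In that setting the toy reduces to ordinary local SGD with periodic averaging. A positive outer momentum carries part of the previous round's update forward.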

Citation

If you find our work useful, please cite the following paper:

@article{sun2024co2,
  title={CO2: Efficient Distributed Training with Full Communication-Computation Overlap},
  author={Sun, Weigao and Qin, Zhen and Sun, Weixuan and Li, Shidi and Li, Dong and Shen, Xuyang and Qiao, Yu and Zhong, Yiran},
  journal={arXiv preprint arXiv:2401.16265},
  year={2024}
}
