This repository contains the official implementation for the paper "How Does Critical Batch Size Scale in Pre-training?"
We find that the critical batch size (CBS) scales primarily with dataset size rather than model size. For practitioners, this means that increasing the batch size beyond the CBS yields little additional efficiency, and that scaling up the dataset is the more effective lever for exploiting data parallelism. These findings may offer practical insights into efficient scaling strategies for pre-training large language models.
Please see our Blog Post for a high-level overview of the paper.
First install PyTorch according to the instructions specific to your operating system.
To install from source, run:
pip install -e .[all]
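For reference, a minimal end-to-end setup might look like the following sketch. The environment name and the CUDA-specific PyTorch install command are assumptions and may differ on your system:
# Example environment setup; the environment name and CUDA version are illustrative
conda create -n cbs python=3.10 -y
conda activate cbs
# Install a PyTorch build matching your CUDA version (see pytorch.org for the exact command)
pip install torch --index-url https://download.pytorch.org/whl/cu121
# Install this repository with all optional dependencies; quoting avoids shell globbing of the extras specifier
pip install -e ".[all]"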
We use Dolma to download and preprocess the data.
cd dolma
pip install -e .
To download and preprocess the data, run:
python scripts/data/download.py
bash scripts/data/tokenizer.sh
We use transformer-based auto-regressive models of different sizes. The configuration for each model size can be found at configs/critical-bs/${MODEL_SIZE}.yaml.
Configurations to reproduce the various experiments can be found in the configs/critical-bs directory.
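As a rough sketch, a single training run with one of these configs might be launched as follows. The exact entry point and launcher arguments are assumptions (scripts/train.py is referenced in the checkpointing note below, and the positional config argument may differ in this repo):
# Hypothetical single-node launch of the 151M config on 8 GPUs
torchrun --nproc_per_node=8 scripts/train.py configs/critical-bs/151M.yaml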
We use wandb sweeps for hyperparameter search (please refer to our paper for hyperparameter details). To run a sweep, use the following commands:
wandb sweep configs/critical-bs/sweeps/151M_sweep.yaml
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 wandb agent cbs/cbs/$SWEEP_ID
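The first command prints the ID of the newly created sweep; set $SWEEP_ID to that value before starting the agent. A minimal sketch (the sweep ID below is a placeholder, and the cbs/cbs entity/project path is taken from the command above and may differ for your wandb account):
# Placeholder sweep ID; use the ID printed by the `wandb sweep` command above
export SWEEP_ID=abc123
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 wandb agent cbs/cbs/$SWEEP_ID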
We disable checkpointing by default. To enable it, go to scripts/train.py and uncomment the lines related to checkpointing.
If you find this repo useful, please consider citing:
@misc{zhang2024cbs,
  title={How Does Critical Batch Size Scale in Pre-training?},
  author={Hanlin Zhang and Depen Morwani and Nikhil Vyas and Jingfeng Wu and Difan Zou and Udaya Ghai and Dean Foster and Sham Kakade},
  year={2024},
  eprint={2410.21676},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2410.21676},
}