This is a JAX/Flax reimplementation of the GPT-2 family of models, trained on the FineWeb-Edu dataset and inspired by karpathy/build_nanoGPT.
Updates:
- Add support for `tf.data` pipelines over TFRecords.
- Add support for `bfloat16` computation.
- SPMD (multi-node) training support using `pmap`.
- Expose configuration options via CLI flags (or a config dict).
- Use the cuDNN flash attention kernel via the SDPA API (jax-ml/jax#22546).
- `nn.Embed` typecast performance issue.
- Use scaled initialization for residual paths.
- Fix large gradient norm spikes for longer training runs.
- Test `accumulate_gradient`.
- Update docstrings.
- Add `shard_map` support for model and data sharding.
- KV cache decoding.
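Scaled initialization for residual paths follows the GPT-2 convention as reproduced in build_nanoGPT: layers that write into the residual stream get a standard deviation of 0.02 multiplied by `(2 * n_layer) ** -0.5`, because each block adds two residual contributions (attention and MLP). The sketch below assumes that convention; `residual_proj_init` is a hypothetical helper, and the repository may wire the scaling differently (for example through a Flax `kernel_init` argument).

```python
import jax
import jax.numpy as jnp

BASE_STD = 0.02  # GPT-2 / nanoGPT base initialization std (assumed convention)


def residual_proj_init(n_layer):
    """Hypothetical initializer for projections that write into the residual stream.

    Each transformer block adds two residual contributions (attention and MLP),
    so the std is scaled by (2 * n_layer) ** -0.5 to keep the variance of the
    residual stream from growing with depth.
    """
    return jax.nn.initializers.normal(stddev=BASE_STD * (2 * n_layer) ** -0.5)


init = residual_proj_init(n_layer=12)
# Initializers take (key, shape, dtype) and return a freshly sampled array.
w = init(jax.random.PRNGKey(0), (768, 768), jnp.float32)
print(float(w.std()))  # close to 0.02 / sqrt(24), about 0.0041
```

In Flax, such an initializer would typically be passed as `nn.Dense(features, kernel_init=residual_proj_init(n_layer))` on the attention output and MLP output projections only.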
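Gradient accumulation splits a large batch into micro-batches, computes a gradient for each, and averages them, so the effective batch size can exceed what fits in device memory. This is a minimal JAX sketch assuming a toy linear model with a mean-squared-error loss and equally sized micro-batches; `loss_fn` and `accumulated_grads` are hypothetical names and do not reflect the signature of the repository's `accumulate_gradient`.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Mean-squared error of a toy linear model y_hat = x @ w + b.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def accumulated_grads(params, x, y, num_micro):
    """Average gradients over `num_micro` equal-sized micro-batches using lax.scan."""
    batch = x.shape[0]
    assert batch % num_micro == 0, "batch must divide evenly into micro-batches"
    # Reshape (B, ...) -> (num_micro, B // num_micro, ...) so scan iterates over micro-batches.
    xs = x.reshape(num_micro, batch // num_micro, *x.shape[1:])
    ys = y.reshape(num_micro, batch // num_micro, *y.shape[1:])
    grad_fn = jax.grad(loss_fn)

    def step(acc, micro):
        mx, my = micro
        g = grad_fn(params, mx, my)
        # Keep a running sum of micro-batch gradients.
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(step, zeros, (xs, ys))
    # Each micro-batch loss is a mean, so averaging the summed gradients
    # reproduces the full-batch gradient when micro-batches are equal-sized.
    return jax.tree_util.tree_map(lambda g: g / num_micro, total)


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (16, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3
params = {"w": jax.random.normal(key_w, (3,)), "b": jnp.array(0.0)}

full = jax.grad(loss_fn)(params, x, y)
accum = accumulated_grads(params, x, y, num_micro=4)
```

In a GPT training loop, `params` would be the model's parameter pytree and `loss_fn` the next-token cross-entropy; the averaged gradients are then handed to the optimizer once per accumulated batch.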
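KV cache decoding avoids recomputing keys and values for the whole prefix at every generation step: each new token's key and value are appended to a cache, and only the new token's query attends over the cached entries. The sketch below is a single-head NumPy illustration of that idea, not the repository's Flax implementation; `causal_attention`, `decode_with_cache`, and the random projection matrices are hypothetical. It checks that token-by-token cached decoding matches full causal attention.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def causal_attention(x, wq, wk, wv):
    """Full (non-cached) single-head causal self-attention over x of shape (T, C)."""
    q, k, v = x @ wq, x @ wk, x @ wv
    t, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    # Mask out future positions (strictly above the diagonal).
    mask = np.triu(np.ones((t, t), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)
    return softmax(scores) @ v


def decode_with_cache(x, wq, wk, wv):
    """Process tokens one at a time, appending each key/value to a growing cache."""
    d = wq.shape[1]
    k_cache = np.zeros((0, d))
    v_cache = np.zeros((0, d))
    outputs = []
    for t in range(x.shape[0]):
        xt = x[t:t + 1]  # (1, C): only the newest token is projected
        q = xt @ wq
        k_cache = np.concatenate([k_cache, xt @ wk], axis=0)
        v_cache = np.concatenate([v_cache, xt @ wv], axis=0)
        # The newest query may attend to every cached position, so no mask is needed.
        scores = q @ k_cache.T / np.sqrt(d)
        outputs.append(softmax(scores) @ v_cache)
    return np.concatenate(outputs, axis=0)


rng = np.random.default_rng(0)
T, C, D = 5, 8, 4
x = rng.normal(size=(T, C))
wq, wk, wv = (rng.normal(size=(C, D)) for _ in range(3))
full = causal_attention(x, wq, wk, wv)
cached = decode_with_cache(x, wq, wk, wv)
print(np.allclose(full, cached))  # True
```

JAX decoders usually preallocate a fixed-length cache and write into it at the current position instead, because `jax.jit` compiles for static shapes; the growing `np.concatenate` here is only for readability.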
Create a virtual environment and install packages.
pip install -r requirements.txt
For SPMD support (multi-node training), install OpenMPI.
sudo apt install openmpi-bin openmpi-doc libopenmpi-dev
# Also set `data_dir` in `configs/default.py` to this same output directory.
python fineweb.py --outdir ./data
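Once the tokens are on disk, next-token-prediction batches are typically cut from the flat token stream so that the targets are the inputs shifted by one position, as in build_nanoGPT's data loader. This NumPy sketch assumes a shard already loaded as a 1-D integer array; `iterate_batches` is a hypothetical helper, and the repository's own `tf.data` pipeline over TFRecords is not reproduced here.

```python
import numpy as np


def iterate_batches(tokens, batch_size, seq_len):
    """Yield (inputs, targets) of shape (batch_size, seq_len) from a 1-D token array.

    Each chunk reads batch_size * seq_len + 1 tokens so that the targets are the
    inputs shifted left by one position (next-token prediction).
    """
    chunk = batch_size * seq_len
    pos = 0
    while pos + chunk + 1 <= len(tokens):
        buf = tokens[pos:pos + chunk + 1]
        x = buf[:-1].reshape(batch_size, seq_len)
        y = buf[1:].reshape(batch_size, seq_len)
        yield x, y
        # Advance by one chunk; the extra target token becomes the next chunk's first input.
        pos += chunk


# GPT-2's 50257-token vocabulary fits in uint16; arange stands in for real token ids.
tokens = np.arange(20, dtype=np.uint16)
batches = list(iterate_batches(tokens, batch_size=2, seq_len=4))
x0, y0 = batches[0]
# x0 == [[0, 1, 2, 3], [4, 5, 6, 7]]; y0 == [[1, 2, 3, 4], [5, 6, 7, 8]]
```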
# Single process, multi-GPU.
python train.py --workdir artifacts/gpt2_124M --config configs/default.py
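In the single-process, multi-GPU mode, `pmap`-style data parallelism replicates the parameters on every local device, gives each device its own slice of the batch, and averages gradients across devices with a collective. A minimal sketch of that pattern under stated assumptions: a toy linear model stands in for GPT-2, and `train_step`, the axis name `"batch"`, and the learning rate of 0.1 are hypothetical, not taken from `train.py`. It also runs on a single CPU device.

```python
from functools import partial

import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Toy linear-regression loss standing in for the GPT cross-entropy loss.
    return jnp.mean((x @ params["w"] - y) ** 2)


@partial(jax.pmap, axis_name="batch")
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average gradients and loss over all devices taking part in the pmap.
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    # Plain SGD update, identical on every device because grads are averaged.
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return new_params, loss


n_dev = jax.local_device_count()
params = {"w": jnp.zeros((3,))}
# Replicate parameters by adding a leading device axis of size n_dev.
replicated = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_dev), params)
# Shard the global batch: shape (n_dev, per_device_batch, features).
x = jnp.ones((n_dev, 8, 3))
y = jnp.ones((n_dev, 8))
replicated, loss = train_step(replicated, x, y)
```

Recent JAX releases steer new code toward `jax.jit` with explicit shardings or `shard_map`; `pmap` is shown here only because this README names it.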
# Multi-process on the same host using OpenMPI.
mpirun -n 8 \
-bind-to socket \
python train.py --workdir artifacts/gpt2_124M --config configs/default.py
# Multi-node across 8 hosts (requires passwordless SSH between hosts).
mpirun -n 8 \
-pernode \
-H hostname1,hostname2,...,hostname8 \
-bind-to socket \
python train.py --workdir artifacts/gpt2_124M --config configs/default.py
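When launched through `mpirun`, each process has to learn its rank and the world size before joining the JAX distributed runtime. Open MPI exports these as the `OMPI_COMM_WORLD_RANK`, `OMPI_COMM_WORLD_SIZE`, and `OMPI_COMM_WORLD_LOCAL_RANK` environment variables. The helper below is a hypothetical sketch that reads them with a single-process fallback; how `train.py` actually initializes its processes is not described in this README.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class ProcessInfo:
    rank: int        # global process index across all hosts
    world_size: int  # total number of processes started by mpirun
    local_rank: int  # process index on this host


def mpi_process_info(env=None):
    """Read Open MPI's per-process environment variables, defaulting to one process."""
    env = os.environ if env is None else env
    return ProcessInfo(
        rank=int(env.get("OMPI_COMM_WORLD_RANK", 0)),
        world_size=int(env.get("OMPI_COMM_WORLD_SIZE", 1)),
        local_rank=int(env.get("OMPI_COMM_WORLD_LOCAL_RANK", 0)),
    )


# Example: what process 3 of `mpirun -n 8` on a single host would see.
info = mpi_process_info({
    "OMPI_COMM_WORLD_RANK": "3",
    "OMPI_COMM_WORLD_SIZE": "8",
    "OMPI_COMM_WORLD_LOCAL_RANK": "3",
})
print(info)  # ProcessInfo(rank=3, world_size=8, local_rank=3)
```

These values map onto the `num_processes` and `process_id` arguments of `jax.distributed.initialize`; the `coordinator_address` (host and port of rank 0) still has to be agreed on separately.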
License: MIT