Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BitNet b1.58 training #930

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open

Conversation

gau-nernst
Copy link
Collaborator

@gau-nernst gau-nernst commented Sep 24, 2024

This PR adds training code for BitNet b1.58 (ternary weights - 1.58 bit. The first version of BitNet is binary weights). This is implemented as tensor subclass and integrate nicely with the quantize_() API. I also added 2 extra optimizations:

  • Use INT8 tensor cores for forward pass: BitNet does row-wise abs-max scaling for INT8 activations and tensor-wise abs-mean scaling for ternary weights -> no brainer to use INT8 tensor cores
  • 2-bit FSDP all-gather: I follow float8 optimization, which also performs tensor-wise scaling for weights -> quantize weights to ternary and pack to 2-bit for FSDP all-gather. It's possible to pack to smaller bits (nearer to 1.58-bit limit), but for simplicity, I pack to 2-bit.

Not optimized for inference (yet). A good baseline for inference would be something like A8W2 kernel from GemLite

BitNet b1.58

BitNet b1.58 uses ternary weights: each parameter can only take on 3 distinct values {-1, 0, +1}, thus making a BitNet model very compact. BitNet uses tensor-wise abs-mean scaling for weights (quantize to ternary) and row-wise abs-max scaling for activations (quantize to INT8).

BitNet is originally trained with QAT: the weights and activations are fake-quantized, and straight-through estimator (STE) is used to calculate gradients with respect to floating point weights. This process adds extra overhead over standard training. Our implementation utilizes INT8 Tensor Cores to make up for this loss in speed. In fact, our implementation is faster than BF16 training in most cases.

Usage

from torchao.prototype.quantized_training import bitnet_training
from torchao import quantize_

model = ...
quantize_(model, bitnet_training())

Note: following the BitNet Training Tips, Code and FAQ, user should insert extra RMSNorm before each nn.Linear layers and also remove the original RMSNorm before attention and MLP modules. Calling quantize_(model, bitnet_training()) will NOT perform this for you. You can take a look at our example training script benchmarks/quantized_training/pretrain_llama2.py on how to do this for our Llama model.

When used with FSDP2 training, you can pre-compute BitNet weight scales for the next iteration to synchronize all scales with a single all-reduce operation. This should be done after optimizer step.

from torchao.prototype.quantized_training import precompute_bitnet_scale_for_fsdp

for _ in range(n_steps):
  model(inputs).sum().backward()
  optim.step()
  precompute_bitnet_scale_for_fsdp(model)

Results

Convergence check Ran with my experimental repo https://github.com/gau-nernst/quantized-training. Llama2-1.1B (based on TinyLlama) for 1B tokens on FineWeb-Edu using 1x 4090. Baseline is full BF16 training. Each step is 4x8192 tokens.

image
Model Train loss Train tok/s Hellaswag acc_norm
BF16 baseline 2.97 1300 0.3048
BitNet 3.05 1500 0.2953

Note: at this scale, we don't expect the loss curves to have a gap. According to Figure 2 of FAQ, the gap only appears at around 10B tokens.

Sanity benchmark with built-in training script Using benchmarks/quantized_training/pretrain_llama2.py. Llama2-1B on TinyStories, w/ 4090, 1k steps. Each step is 16x2048 tokens. PyTorch 2.4.0. Full BF16 training

image
Model Train loss tok/s
BF16 baseline 1.54 18,139
BitNet reference 1.52 17,510 (-3.5%)
BitNet ao (w/ INT8 tensor cores) 1.50 22,751 (+25%)

The train loss is a bit strange, but I think training on TinyStories is not so reliable. Perhaps it's just numerics. Side note that the speedup is impressive is because INT8 tensor cores is very fast on 4090 (up to 3.5x faster than BF16 tensor cores).

FSDP2 benchmark w/ torchtitan Using https://github.com/gau-nernst/torchtitan/tree/bitnet. Llama3-8B on C4, default config, 4x A100. torch==2.6.0.dev20240924

image
Model Train loss tok/s
BF16 mixed-precision baseline 4.94 2,573
BitNet (BF16 all-gather) 4.98 2,766 (+7.5%)
BitNet (2-bit all-gather) 4.97 2,876 (+12%)
BitNet (2-bit all-gather + precompute scale) 4.99 2,879 (+12%)

Note: due to the way torchtitan initialize weights, it's a bit troublesome to add extra RMSNorm layers as recommended by the paper. Thus, for benchmarks in torchtitan, I don't add extra RMSNorm.

Copy link

pytorch-bot bot commented Sep 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/930

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6ab6863 with merge base 72d2518 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 24, 2024
@gau-nernst gau-nernst marked this pull request as ready for review September 25, 2024 13:11
elif args.quantize == "int8_mixed_precision":
quantize_(model.layers, int8_mixed_precision_training(), set_inductor_config=False)

elif args.quantize == "bitnet":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optional: this is "change model architecture and then do quantization" which is pretty different from just "quantization". For code clarity, maybe we can either have an explicit preprocessing step to be called separately, or call the arg something like rmsnorm_model_surgery_then_quantize_bitnet?


def _pack_i2_in_i8(x: Tensor):
# NOTE: this is signed integer, so we have to mask before bit-shift
return (x[:, ::4] << 6) | ((x[:, 1::4] & 0b11) << 4) | ((x[:, 2::4] & 0b11) << 2) | (x[:, 3::4] & 0b11)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

readability nit: write it out line by line with comments to make easier to understand for code readers?

Copy link
Contributor

@vkuzo vkuzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

LGTM for prototype but feel free to wait for other reviews if needed

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants