Skip to content

Conversation

@tobiasvanderwerff
Copy link
Contributor

@tobiasvanderwerff tobiasvanderwerff commented Oct 23, 2024

Closes #998

This PR adds BF16 support for the existing FPx CUDA kernel (here), which was originally written for FP16 (see the paper for details). Since all recent models are trained and released with BF16, having BF16 support potentially improves accuracy for FPx models.

Most important changes

  1. Use C++ templates in the kernel to allow for either FP16 or BF16 input.
  2. Change Python wrapper code for floatx/fpx to take into account BF16 inputs.
  3. Add tests for bf16 (test/test_ops.py and test/dtypes/test_floatx.py).
  4. Change dequant logic for BF16 (see FPx_FP16_Cast_4Way function). This is simply a matter of changing the bit shifting to take into account the difference in exponents bits between FP16 and BF16 (5 vs. 8, respectively). Figure 6 in the original paper contains a helpful visualization of the bit storage pattern.
  5. Change tensor core MMA instructions for BF16 (see MMA_FP16_M16N8K16 function). Also straightforward, since it only involves changing the input types of the MMA instructions.
  6. Modified exponent bias calculations (see MultScale function). This turned out to be a little more involved and required a custom solution. See the section below for more details.

More details on modified exponent bias calculations

Adapting the FP6 kernel for BF16 introduces a complication regarding the exponent bias in the dequantization step compared to the original implementation for FP16. This required me to come up with a custom solution. For context: in section 4.2.2 and 5.3.1 of the FP6 paper, they mention that the exponent for FP16 (and also equivalent for BF16) is

$E^{\text{f⁢p⁢16}} = E^{\text{f⁢p⁢x}} + \text{b⁢i⁢a⁢s}^{\text{f⁢p⁢16}} − \text{b⁢i⁢a⁢s}^{\text{f⁢p⁢x}}$

Adding the constant bias terms is computationally expensive to do during dequantization, so instead they set $E^{\text{f⁢p⁢16}} = E^{\text{f⁢p⁢x}}$ during dequantization, followed by multiplication later on with $2^{\text{b⁢i⁢a⁢s}^{\text{f⁢p⁢16}} − \text{b⁢i⁢a⁢s}^{\text{f⁢p⁢x}}}$ to get the equivalent result in a more efficient way. This works fine for FP16, since $\text{b⁢i⁢a⁢s}^{\text{fp⁢16}} = 15$ and so $2^{\text{b⁢i⁢a⁢s}^{\text{f⁢p⁢16}} − \text{b⁢i⁢a⁢s}^{\text{f⁢p⁢x}}} = 2^{15-3} = 2^{12}$ for FP6, which fits into a single 32-bit integer value.

Translating this to BF16 however, is a little more tricky. Since $\text{b⁢i⁢a⁢s}^{\text{bf⁢16}} = 127$ (see Wikipedia), this would amount to multiplication by $2^{\text{b⁢i⁢a⁢s}^{\text{bf⁢16}} − \text{b⁢i⁢a⁢s}^{\text{f⁢p⁢x}}} = 2^{127-3} = 2^{124}$, which is too large to fit into a 32-bit or even 64-bit number. To address this, I experimented with a few approaches, which I highlight below.

Approach 1: Type punning (current choice)

EDIT: The solution below is now simplified by using the CUDA ldexpf function. Thanks to @gau-nernst for pointing me in the right direction!

$2^{124}$ is too large to fit into an integer, but it can fit into a 32-bit floating point number. Since the exponent for IEEE 754 floating point has 8 bits, it can represent a maximum exponent of $2^{127}$ (see Wikipedia). So, we can represent $2^{124}$ by directly setting the exponent bits and constructing a float, like so:

union {
    uint32_t u32;
    float f;
} tmp;
tmp.u32 = (124 + 127) << 23;  // 127=exponent bias, 23=mantissa
// tmp.f now contains a float with sign=0, mantissa=0, exponent=124, 
// which translates to the desired floating point value of 2^124
__nv_bfloat16 result = __hmul(*BF16_1,__float2bfloat16(tmp.f));

Approach 2: Decompose the exponent bias into smaller values

A second approach is based on the insight that instead of multiplying by $2^{124}$, we can notice that $2^{124} = 2^{31} \cdot 2^{31} \cdot 2^{31} \cdot 2^{31}$. Since $2^{31}$ does fit into an ordinary 32-bit unsigned int, we can multiply 4 times by this value, like so:

__nv_bfloat16 tmp1 = *BF16_1;
const __nv_bfloat16 BIAS = __float2bfloat16(1.0f * (uint32_t(1) << (124 / 4)));
#pragma unroll
for (int i = 0; i < 4; i++) {
    tmp1 = __hmul(tmp1, BIAS);
}
__nv_bfloat16 result = tmp1;

This approach works, but I think it is less preferrable to approach 1 since it introduces 3 additional multiplications.

Benchmark

Tested on:

  • A100 GPU
  • Torch version: 2.6.0.dev20241022
  • CUDA version: 12.2
m k n fp6-fp16 latency (ms) fp16 latency (ms) speedup fp16 correct fp16 fp6-bf16 latency (ms) bf16 latency (ms) speedup bf16 correct bf16
1 8192 8192 47.8153 107.905 2.25671 1 48.0822 108.038 2.24693 1
1 8192 10240 56.5186 129.677 2.29441 1 56.7016 129.925 2.29138 1
1 8192 57344 261.31 679.021 2.59853 1 261.341 681.108 2.60621 1
1 28672 8192 142.66 340.947 2.38992 1 142.768 341.311 2.39066 1
2 8192 8192 48.1587 108.817 2.25954 1 48.5589 108.727 2.23908 1
2 8192 10240 56.8241 133.974 2.35769 1 57.2074 134.379 2.34898 1
2 8192 57344 264.475 684.25 2.58721 1 264.451 683.928 2.58622 1
2 28672 8192 143.41 338.273 2.35878 1 143.463 338.418 2.35893 1
4 8192 8192 49.0373 108.863 2.22 1 49.3801 108.82 2.20372 1
4 8192 10240 57.9573 134.769 2.32532 1 58.1178 135.255 2.32726 1
4 8192 57344 269.995 688.162 2.5488 1 270.09 687.903 2.54694 1
4 28672 8192 144.317 339.184 2.35027 1 144.298 339.239 2.35096 1
8 8192 8192 50.7631 109.538 2.15783 1 50.9435 108.694 2.13361 1
8 8192 10240 59.6324 136.478 2.28866 1 59.8107 136.604 2.28394 1
8 8192 57344 281.025 694.71 2.47206 1 281.094 694.339 2.47013 1
8 28672 8192 146.231 340.169 2.32625 1 146.197 340.209 2.32707 1
16 8192 8192 54.6856 111.104 2.0317 1 54.8701 110.281 2.00986 1
16 8192 10240 63.8146 139.737 2.18973 1 63.9122 139.169 2.1775 1
16 8192 57344 310.406 705.661 2.27335 1 310.524 705.344 2.27147 1
16 28672 8192 153.435 342.533 2.23243 1 153.635 342.03 2.22625 1
32 8192 8192 59.8124 109.968 1.83855 1 60.0498 117.015 1.94863 1
32 8192 10240 70.0894 137.461 1.96122 1 70.0611 134.311 1.91706 1
32 8192 57344 361.694 713.476 1.97259 1 361.379 747.143 2.06748 1
32 28672 8192 158.193 352.667 2.22935 1 159.116 365.284 2.29571 1
64 8192 8192 75.025 111.396 1.48478 1 75.3286 116.787 1.55037 1
64 8192 10240 86.4957 133.782 1.54669 1 86.5149 137.81 1.59291 1
64 8192 57344 496.407 717.28 1.44494 1 495.186 848.441 1.71338 1
64 28672 8192 198.356 363.766 1.8339 1 198.211 378.897 1.91158 1
128 8192 8192 116.348 134.703 1.15776 1 115.112 136.111 1.18242 1
128 8192 10240 152.802 162.002 1.06021 1 151.848 144.769 0.953382 1
128 8192 57344 829.325 813.571 0.981003 1 821.493 860.491 1.04747 1
128 28672 8192 363.332 379.417 1.04427 1 358.096 377.712 1.05478 1
256 8192 8192 232.587 176.013 0.756765 1 232.328 178.489 0.768264 1
256 8192 10240 278.047 225.67 0.811624 1 275.367 215.422 0.78231 1
256 8192 57344 1517.54 1099.87 0.72477 1 1498.89 1092.88 0.729125 1
256 28672 8192 676.377 526.53 0.778457 1 667.652 536.409 0.803426 1
512 8192 8192 455.454 338.41 0.743016 1 451.99 328.303 0.726351 1
512 8192 10240 496.543 363.535 0.732133 1 492.433 410.334 0.833279 1
512 8192 57344 2528.04 2005.78 0.793411 1 2530.73 1963.43 0.775835 1
512 28672 8192 1338.7 1013.8 0.757304 1 1321.18 983.968 0.744764 1

Final remarks

  • I've added guards that should still make the kernel compatible with SM75, but I have not confirmed this because I was testing on an SM80 GPU.
  • I've only tested this for FP6, but I think it should most likely also work for the other currently implemented FPx types.

CC @gau-nernst

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1147

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a6de35a with merge base eb1fb3a (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 23, 2024
Copy link
Collaborator

@gau-nernst gau-nernst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome work! Leave some comments for discussion.

@tobiasvanderwerff
Copy link
Contributor Author

Here are some benchmark results on Llama-2-7b-chat-hf.

Tested on:

  • A100 GPU
  • Torch version: 2.6.0.dev20241023
  • CUDA version: 12.2

Wikitext perplexity (using torchao/_models/llama/eval.py):

fp6-float16:  12.3648
fp6-bfloat16: 12.3669

FP16 perf.

python generate.py --compile --quantization fp6 --precision float16

Average tokens/sec: 111.55
Average Bandwidth: 553.13 GB/s
Peak Memory Usage: 6.69 GB
Model Size: 4.96 GB

BF16 perf.

python generate.py --compile --quantization fp6 --precision bfloat16

Average tokens/sec: 111.19
Average Bandwidth: 551.35 GB/s
Peak Memory Usage: 6.69 GB
Model Size: 4.96 GB

# doesn't seem to be the right way to check for correctness
correct = (fp6_output - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3
correct_fp16 = (fp6_output_fp16 - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3
correct_bf16 = (fp6_output_bf16 - bf16_output).abs().mean() / bf16_output.abs().mean() < 1e-2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious. I saw that generally when BF16 is used, tolerance is quite higher than FP16. From your experience working on this, you do suspect any part of the code might result in this loss of precision? e.g. perhaps some parts are computed in BF16 instead of FP32. Or maybe it's just the way it is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All I know is that BF16 has fewer bits for the fraction (mantissa) than FP16 (10 bits vs. 7 bits), so that leads to lower precision for BF16 compared to FP16. I can't think of any part of the FP6 kernel that would inherently lead to more loss of precision for BF16.

Copy link
Collaborator

@gau-nernst gau-nernst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the great work! I will request Mark and Charles to take another scan at the PR.

Edit: almost forgot. @tobiasvanderwerff Can you update the docs also? Mainly this one I think https://github.com/pytorch/ao/tree/main/torchao/dtypes/floatx. You can mention that you extended the original kernel to work with BF16 😄. I think there might be other places where we mentioned FPx kernel only works with FP16.

Copy link
Member

@msaroufim msaroufim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 to adding docs, left a few small comments - this is close to getting landed

int Split_K)
{
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
static_assert(std::is_same<InputDataType, half>::value || std::is_same<InputDataType, __nv_bfloat16>::value, "Type must be float or __nv_bfloat16");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in warning did you mean float16 instead of float

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes!

"r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
if constexpr (USE_BF16) {
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO for myself is to add some comments explain this asm

@msaroufim msaroufim self-requested a review October 25, 2024 06:30
Copy link
Member

@msaroufim msaroufim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 to adding docs, left a few small comments - this is close to getting landed

Any check for `__CUDA_ARCH__` in `fp6_linear.cu` will always fail because `__CUDA_ARCH__` is undefined since all of the functions in `fp6_linear.cu` are host functions
@tobiasvanderwerff
Copy link
Contributor Author

tobiasvanderwerff commented Oct 29, 2024

@msaroufim @gau-nernst I updated the docs at https://github.com/pytorch/ao/tree/main/torchao/dtypes/floatx. I didn't update the benchmark results with BF16 however, because I don't have access to the same machine that the existing benchmarks results used. I'm a bit short on time today so I may have missed some additional docs that should be updated -- I'll double check this later. I also intend to do a sm75 test because that would be good to check I think. I still need to address some of your comments; I'll also do this a bit later.

@msaroufim msaroufim self-requested a review October 30, 2024 19:28
If this is not done, the kernel may still run but fail silently, leading to unexpected behavior
There are currently several ways of using `__CUDA_ARCH__` that lead to undefined behavior. See https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-arch for details of how `__CUDA_ARCH__` should not be used
@gau-nernst gau-nernst self-requested a review November 1, 2024 13:26
Copy link
Collaborator

@gau-nernst gau-nernst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new changes look good! Figuring out how to handle CUDA_ARCH took more efforts than expected, but now we all understand it better 😄.
I think there are some problems with CUDA CI right now. Will need to wait for that to be fixed.

@msaroufim msaroufim merged commit 8c07d22 into pytorch:main Nov 4, 2024
17 checks passed
@tobiasvanderwerff tobiasvanderwerff deleted the fp6-bf16 branch November 4, 2024 07:49
@fxmarty-amd
Copy link

cool work!

@gau-nernst gau-nernst mentioned this pull request Feb 1, 2025
3 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Make Quant-LLM compatible with BF16

5 participants