[Not for land] Integrate float8nocompile, an experimental feature for high performance #778

Open

danielvegamyhre wants to merge 6 commits into main
Conversation

danielvegamyhre commented Jan 7, 2025

Summary

This PR contains a proof-of-concept integration of an experimental feature I've been working on in torchao: float8nocompile (official name TBD, naming things is hard!).

It is an implementation of float8 conversion with tensorwise dynamic scaling that uses handwritten Triton kernels to achieve high performance, rather than requiring torch.compile.
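For illustration, here is a minimal PyTorch-level sketch of the tensorwise dynamic scaling math (not the actual float8nocompile code, which implements the same logic in handwritten Triton kernels):

```python
import torch

# max representable value of the float8 e4m3 format used for the cast
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def dynamic_tensorwise_cast(x: torch.Tensor):
    # "dynamic" scaling: recompute one scale per tensor from its current amax
    amax = x.abs().max().to(torch.float64)
    scale = (E4M3_MAX / torch.clamp(amax, min=1e-12)).to(torch.float32)
    # scale, saturate, and cast to float8; the scale is kept for the scaled matmul
    x_fp8 = (x.to(torch.float32) * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale
```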

Benchmarking training performance

Model: Llama3 8b on 4 H100s with batch size of 1, seq len of 4096.

TL;DR: the prototype slightly beats the baseline we want to match (production float8 + compiling only the linear layers) for both full activation checkpointing (AC) and no AC.

No AC (seq len 4096) - 8 H100s

| Configuration | TFLOPS | Tokens/sec | Peak memory usage |
| --- | --- | --- | --- |
| bfloat16, eager mode | 0.2339 | 2.31E+14 | 231.37 |
| float8nocompile, eager mode | 0.2504 | 2.48E+14 | 247.65 |
| bfloat16, torch.compile | 0.2618 | 2.59E+14 | 258.92 |
| float8, torch.compile only nn.Linear layers | 0.2460 | 2.43E+14 | 243.27 |
| float8, torch.compile full model | 0.3002 | 2.97E+14 | 296.92 |

Full AC (seq len 4096) - 8 H100s

| Configuration | TFLOPS | Tokens/sec | Peak memory usage |
| --- | --- | --- | --- |
| bfloat16, eager mode | 0.2343 | 2.32E+14 | 231.74 |
| float8nocompile, eager mode | 0.2494 | 2.47E+14 | 246.62 |
| bfloat16, torch.compile | 0.2643 | 2.61E+14 | 261.43 |
| float8, torch.compile only nn.Linear layers | 0.2460 | 2.43E+14 | 243.29 |
| float8, torch.compile full model | 0.3005 | 2.97E+14 | 297.23 |

Benchmarking single linear layer forward+backward performance

Tests used a single linear layer of size (4096, 4096) with different input sizes.

Performance benchmarks show the float8nocompile implementation is beating torch.compile by 1.72-4.45% depending on the input tensor size.

| input_shape | kernel_algo | high_precision_dtype | eager_time | compiled_time | float8nocompile |
| --- | --- | --- | --- | --- | --- |
| (16, 4096) | KernelAlgorithm.ATOMIC_MAX | torch.bfloat16 | 649.218 | 394.725 | 386.469 |
| (256, 4096) | KernelAlgorithm.ATOMIC_MAX | torch.bfloat16 | 685.783 | 420.743 | 408.137 |
| (4096, 4096) | KernelAlgorithm.ATOMIC_MAX | torch.bfloat16 | 1829.13 | 1053.64 | 977.858 |
| (65536, 4096) | KernelAlgorithm.ATOMIC_MAX | torch.bfloat16 | 21554.2 | 12369.7 | 10813.3 |
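For context, the ATOMIC_MAX algorithm computes the global amax with a kernel in which each block reduces its tile and then atomically updates a single global maximum. A rough sketch of that reduction pattern (illustrative only, not the actual torchao kernel):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _amax_atomic_kernel(x_ptr, amax_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # padding with 0.0 is safe here since abs() values are non-negative
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    # per-block reduction, then a single atomic update of the global amax
    block_amax = tl.max(tl.abs(x), axis=0)
    tl.atomic_max(amax_ptr, block_amax)

def tensorwise_amax(x: torch.Tensor) -> torch.Tensor:
    x = x.contiguous()
    amax = torch.zeros(1, dtype=torch.float32, device=x.device)
    n_elements = x.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    _amax_atomic_kernel[grid](x, amax, n_elements, BLOCK_SIZE=1024)
    return amax
```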

Note: this PR depends on this stack of PRs being merged into torchao, and on those changes being included in a release (which the user installs).
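At a high level, the integration swaps the model's nn.Linear modules for float8nocompile variants before training starts. The exact API is still in flux in the torchao stack above, so the import path and function name in the sketch below are assumptions, shown only to illustrate the shape of the change:

```python
import torch.nn as nn

def maybe_convert_float8nocompile(model: nn.Module, enabled: bool) -> nn.Module:
    """Hypothetical sketch: swap eligible nn.Linear modules for float8nocompile
    linears (tensorwise dynamic scaling, handwritten Triton kernels, no torch.compile).
    The import path and function name below are assumptions pending the torchao stack."""
    if not enabled:
        return model
    from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
        convert_to_float8_nocompile_training,  # assumed name, subject to change
    )
    convert_to_float8_nocompile_training(model)
    return model
```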

facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Jan 7, 2025
danielvegamyhre (Author) commented

@vkuzo here is the PoC of the torchtitan + float8nocompile integration, and the training performance benchmarking results

vkuzo (Contributor) commented Jan 7, 2025

I think a good way to go here is to mark this PR "not for land" for now

Benchmarking mean MFU during training run

Here is the data I think is important:

  1. in addition to MFU, also report tokens per second and peak memory usage as top-line metrics
  2. report the metrics for the following experiments:
    2a. bfloat16, eager mode
    2b. float8nocompile, eager mode
    2c. bfloat16, torch.compile
    2d. float8, torch.compile, without float8 FSDP all-gather

@danielvegamyhre danielvegamyhre changed the title [PoC] Integrate float8nocompile, an experimental feature for high performance [Not for land] Integrate float8nocompile, an experimental feature for high performance Jan 7, 2025
awgu (Contributor) commented Jan 7, 2025

I think we should take care to mention/keep in mind that the MFU is with respect to peak bf16 TFLOPS. Direct comparison of TFLOPS might make more sense when comparing bf16 vs. fp8 runs since the fp8 runs are not actually doing every computation in fp8 either (e.g. SDPA or final linear).

danielvegamyhre (Author) commented

> I think we should take care to mention/keep in mind that the MFU is with respect to peak bf16 TFLOPS. Direct comparison of TFLOPS might make more sense when comparing bf16 vs. fp8 runs since the fp8 runs are not actually doing every computation in fp8 either (e.g. SDPA or final linear).

Makes sense. Is there an existing way to log TFLOPS during training with torchtitan? Searching around, I don't see one.

awgu (Contributor) commented Jan 8, 2025

@danielvegamyhre I think the TFLOPS is the same as MFU without the peak TFLOPS denominator (and extra factor of 100):

mfu = 100 * num_flop_per_token * tps / gpu_peak_flops

In other words, it is just num_flop_per_token * tps. So you can convert your MFU numbers back to TFLOPS by multiplying by gpu_peak_flops / 100.
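For example, a small helper for that conversion (assuming the H100 bf16 peak of 989 TFLOPS as the gpu_peak_flops used in the MFU calculation):

```python
H100_PEAK_BF16_FLOPS = 989e12  # assumed MFU denominator for H100 bf16

def mfu_to_tflops(mfu_percent: float, gpu_peak_flops: float = H100_PEAK_BF16_FLOPS) -> float:
    # mfu = 100 * num_flop_per_token * tps / gpu_peak_flops
    # => achieved FLOPS = num_flop_per_token * tps = mfu / 100 * gpu_peak_flops
    return mfu_percent / 100.0 * gpu_peak_flops / 1e12  # report in TFLOPS

# e.g. mfu_to_tflops(40.0) -> 395.6
```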

danielvegamyhre (Author) commented

> @danielvegamyhre I think the TFLOPS is the same as MFU without the peak TFLOPS denominator (and extra factor of 100):
>
> mfu = 100 * num_flop_per_token * tps / gpu_peak_flops
>
> In other words, it is just num_flop_per_token * tps. So you can convert your MFU numbers back to TFLOPS by multiplying by gpu_peak_flops / 100.

Ah, of course, thanks! I updated the PR description to include TFLOPS instead of MFU.

vkuzo (Contributor) commented Jan 8, 2025

based on the results shared so far, I think it would be interesting to add one additional experiment branch: production float8 + torch.compile on just the torch.nn.Linear layers. I have some old unlanded code with an example of how to apply this here: #661. If we structure the handwritten kernels right for the chosen AC strategy, we should be able to match the level of performance in that setup.
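For reference, a minimal sketch of that baseline: compile only the nn.Linear submodules in place and leave the rest of the model in eager mode (the actual wiring in #661 may differ):

```python
import torch
import torch.nn as nn

def compile_linear_layers_only(model: nn.Module) -> None:
    # Swap each nn.Linear child for its torch.compile-wrapped version in place,
    # so only the linear layers run through the compiler.
    for module in model.modules():
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                setattr(module, name, torch.compile(child))
```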

danielvegamyhre (Author) commented

> based on the results shared so far, I think it would be interesting to add one additional experiment branch: production float8 + torch.compile on just the torch.nn.Linear layers. I have some old unlanded code with an example of how to apply this here: #661. If we structure the handwritten kernels right for the chosen AC strategy, we should be able to match the level of performance in that setup.

@vkuzo I updated the PR description with benchmarks comparing against prod float8 + compiling linear layers only. Also included benchmarks for no AC vs full AC.
